P8334 题解
题解好像大部分是差分拆贡献,这是 SD 二轮省集 Harry27182 老师的做法,感觉很牛,在此记录。
注:以下用
显然可以先把
具体地,对于
转移即为
另一种情况是最小值变成了另一子树
那么需要分
同理可得
这两种转移完成后,由于从
另外注意到会有相等的
考虑优化 DP 过程,注意到第一种转移是对应位置累加,且每个位置上转移系数相等,可以使用线段树合并。同时由于转移系数只与不超过
进行第二种转移时,可以开一棵临时的线段树,通过合并得到
所以需要实现线段树合并,并支持区间乘,单点加,区间求和,这些操作均不难实现,时空复杂度
附上代码:
#include<iostream>
#include<vector>
#include<algorithm>
#define pb push_back
#define mid ((l+r)>>1)
using namespace std;
const int N=4e5+10;
const int P=4e5;
const int M=3e7+10;
const int mod=998244353;
void add(int &a,int b) {a+=b;if(a>=mod)a-=mod;}
int n,m,rot,res,a[N],b[N],x[N],mn[N],rt[N],inv[N];
vector <int> e[N];
bool cmp(int i,int j) {return mn[i]<mn[j];}
struct sgmtt
{
int t,lc[M],rc[M],w[M],k[M],tag[M];
void cle(int u) {lc[u]=rc[u]=w[u]=k[u]=0,tag[u]=1;}
void pushup(int u) {w[u]=w[lc[u]],k[u]=k[lc[u]],add(w[u],w[rc[u]]),add(k[u],k[rc[u]]);}
void pt(int u,int x) {w[u]=1ll*w[u]*x%mod,k[u]=1ll*k[u]*x%mod,tag[u]=1ll*tag[u]*x%mod;}
void pushdown(int u) {if(tag[u]!=1) pt(lc[u],tag[u]),pt(rc[u],tag[u]),tag[u]=1;}
void update(int u,int l,int r,int L,int R,int x)
{
if(!u||L>R) return;
if(l>=L&&r<=R) {pt(u,x); return;}
pushdown(u);
if(L<=mid) update(lc[u],l,mid,L,R,x);
if(R>mid) update(rc[u],mid+1,r,L,R,x);
pushup(u);
}
void change(int &u,int l,int r,int p,int x)
{
if(!u) u=++t,cle(t);
if(l==r) {add(k[u],x),w[u]=1ll*k[u]*b[l]%mod; return;}
pushdown(u);
if(p<=mid) change(lc[u],l,mid,p,x);
else change(rc[u],mid+1,r,p,x);
pushup(u);
}
int query(int u,int l,int r,int L,int R)
{
if(!u||L>R) return 0;
if(l>=L&&r<=R) return k[u];
pushdown(u); int tr=0;
if(L<=mid) add(tr,query(lc[u],l,mid,L,R));
if(R>mid) add(tr,query(rc[u],mid+1,r,L,R));
return tr;
}
int merg(int u,int v,int l,int r)
{
if(!u||!v) return u+v;
int p=++t; cle(t);
if(l==r) w[p]=w[u],k[p]=k[u],add(w[p],w[v]),add(k[p],k[v]);
else pushdown(u),pushdown(v),lc[p]=merg(lc[u],lc[v],l,mid),rc[p]=merg(rc[u],rc[v],mid+1,r);
if(l<r) pushup(p);
return p;
}
}T;
void dfs(int u,int fat)
{
mn[u]=a[u]; vector <int> p;
for(int v:e[u]) if(v!=fat)
{
dfs(v,u),p.pb(v);
mn[u]=min(mn[u],mn[v]);
}
sort(p.begin(),p.end(),cmp);
int s=p.size(),cur=0;
for(int i=0;i<s;i++)
{
x[i]=1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i]%mod*inv[i+1]%mod;
cur=T.merg(cur,rt[p[i]],1,m);
}
cur=0;
for(int i=s-1;~i;i--)
{
add(x[i],1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i+1]%mod*inv[i+2]%mod);
cur=T.merg(cur,rt[p[i]],1,m),rt[u]=T.merg(rt[u],rt[p[i]],1,m);
}
for(int i=1;i<s;i++) if(mn[p[i]]!=mn[p[i-1]]) T.update(rt[u],1,m,mn[p[i-1]],mn[p[i]]-1,inv[i]);
if(s) T.update(rt[u],1,m,mn[p[s-1]],m,inv[s]);
for(int i=0;i<s;i++) T.change(rt[u],1,m,mn[p[i]],x[i]);
T.change(rt[u],1,m,a[u],T.query(rt[u],1,m,a[u]+1,m)+1),T.update(rt[u],1,m,a[u]+1,m,0),add(res,T.w[rt[u]]);
}
void sol()
{
cin>>n>>rot,res=T.t=0;
for(int i=1;i<=n;i++) cin>>a[i],b[i]=a[i],rt[i]=0,e[i].clear();
sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
for(int i=1,u,v;i<n;i++) cin>>u>>v,e[u].pb(v),e[v].pb(u);
dfs(rot,0),cout<<res<<'\n';
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
inv[0]=inv[1]=1;
for(int i=2;i<=P;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
int TT; cin>>TT;
while(TT--) sol();
return 0;
}