P8334 题解

· · 题解

题解好像大部分是差分拆贡献,这是 SD 二轮省集 Harry27182 老师的做法,感觉很牛,在此记录。

注:以下用 w(x,y) 代表原题中的 f(x,y)mn_u 表示 u 子树内 a 的最小值。

显然可以先把 a 离散化,之后考虑对最终答案相同的方案整体计算贡献。设 f_{u,i} 表示对于所有合法的 w(u,x),它们中值为 i 的期望个数。转移时讨论子树内路径最小值是否仍为最小,乘上组合数转移。

具体地,对于 u 的儿子 v,若以 v 为起点时的最小值 i 仍最小,需要 u 其他子树中 mn_x\le i 的所有 x 在搜索时均排在 v 之后。此处计算方案数时,可设 u 共有 S 个儿子,则先拿出 mn_x\le ic 个(其中必然包含 v),要求 v 在开头,其他随意排列,剩余的 (S-c) 个也随意排列。之后将 S 个位置分类并分别放入,即为合法排列数。因此概率为

\frac{{S\choose c}\times (c-1)!\times (S-c)!}{S!}=\frac{S!}{c!\times (S-c)!}\times (c-1)!\times (S-c)!\times \frac 1{S!}=\frac 1 c.

转移即为 f_{u,i}\leftarrow f_{v,i}\times\frac 1 {\sum_{x\in son(u)}[mn_x\le i]}

另一种情况是最小值变成了另一子树 x 的最小值 mn_x,最终结束于 v 子树内的终点,这里显然有限制 mn_x\le i。此时先将 u 的所有子树按 mn 从小到大排序,设 rk_x 为排序后 x 的排名。则所有排在 x 前面的子树中,除 v 外的其他子树必须排在 v 之后。

那么需要分 vx 之前和之后讨论,以确定需要在 v 之后的子树个数。这里以 vx 之前为例,仿照上面可以把 rk_x 个子树拿出来单独排列,并将 x,v 分别放到前两位,概率即为

\frac{{S\choose rk_x}\times (rk_x-2)!\times (S-rk_x)!}{S!}=\frac{S!}{rk_x!\times (S-rk_x)!}\times (rk_x-2)!\times (S-rk_x)!\times \frac 1{S!}=\frac 1 {rk_x(rk_x-1)}.

同理可得 vx 之后时系数为 \frac 1{rk_x(rk_x+1)}。转移时为降低复杂度,可以枚举 x,从而用前缀和优化省去对 v,i 的枚举,即

f_{u,mn_x}\leftarrow\frac 1 {rk_x(rk_x-1)}\sum_{rk_v<rk_x}\sum_{i\ge mn_x} f_{v,i}+\frac 1{rk_x(rk_x+1)}\sum_{rk_v>rk_x}\sum_{i\ge mn_x} f_{v,i}.

这两种转移完成后,由于从 u 出发必然经过 a_u,需要把所有 f_{u,i}ia_u 取 min,即将大于 a_u 的 DP 值均加到 f_{u,a_u} 上并清空。最后给答案加上所有 w(u,x) 的贡献,即 \sum i\times f_{u,i}

另外注意到会有相等的 a 值,这时钦定 f_{v,i} 中的 i 为等大的数中最大的,其余相等的 mn_xrk 在前的较小,这样定义后整个 DP 过程即上述,可以实现不重不漏。目前时间复杂度为 O(n^2)

考虑优化 DP 过程,注意到第一种转移是对应位置累加,且每个位置上转移系数相等,可以使用线段树合并。同时由于转移系数只与不超过 imn_x 个数有关,不同的区间只有 O(deg_u) 个,可以用区间乘解决。

进行第二种转移时,可以开一棵临时的线段树,通过合并得到 rk 数组上前后缀的线段树,再进行 i\ge mn_x 的区间查询,最后进行单点加即可。注意两者的系数不同,需要分别顺序和逆序做。最后对 a_u 取 min 只需区间查询,清空即为区间乘 0,也是区间乘。

所以需要实现线段树合并,并支持区间乘,单点加,区间求和,这些操作均不难实现,时空复杂度 O(n\log n)。由于有区间乘和临时的前后缀线段树,最终空间大概需要 4 倍的 n\log n3\times 10^7 就足够了。

附上代码:

#include<iostream>
#include<vector>
#include<algorithm>
#define pb push_back
#define mid ((l+r)>>1)
using namespace std;
const int N=4e5+10;
const int P=4e5;
const int M=3e7+10;
const int mod=998244353;
void add(int &a,int b) {a+=b;if(a>=mod)a-=mod;}
int n,m,rot,res,a[N],b[N],x[N],mn[N],rt[N],inv[N];
vector <int> e[N];
bool cmp(int i,int j) {return mn[i]<mn[j];}
struct sgmtt
{
    int t,lc[M],rc[M],w[M],k[M],tag[M];
    void cle(int u) {lc[u]=rc[u]=w[u]=k[u]=0,tag[u]=1;} 
    void pushup(int u) {w[u]=w[lc[u]],k[u]=k[lc[u]],add(w[u],w[rc[u]]),add(k[u],k[rc[u]]);}
    void pt(int u,int x) {w[u]=1ll*w[u]*x%mod,k[u]=1ll*k[u]*x%mod,tag[u]=1ll*tag[u]*x%mod;}
    void pushdown(int u) {if(tag[u]!=1) pt(lc[u],tag[u]),pt(rc[u],tag[u]),tag[u]=1;}
    void update(int u,int l,int r,int L,int R,int x)
    {
        if(!u||L>R) return;
        if(l>=L&&r<=R) {pt(u,x); return;}
        pushdown(u);
        if(L<=mid) update(lc[u],l,mid,L,R,x);
        if(R>mid) update(rc[u],mid+1,r,L,R,x);
        pushup(u);
    }
    void change(int &u,int l,int r,int p,int x)
    {
        if(!u) u=++t,cle(t);
        if(l==r) {add(k[u],x),w[u]=1ll*k[u]*b[l]%mod; return;}
        pushdown(u);
        if(p<=mid) change(lc[u],l,mid,p,x);
        else change(rc[u],mid+1,r,p,x);
        pushup(u);
    }
    int query(int u,int l,int r,int L,int R)
    {
        if(!u||L>R) return 0;
        if(l>=L&&r<=R) return k[u];
        pushdown(u); int tr=0;
        if(L<=mid) add(tr,query(lc[u],l,mid,L,R));
        if(R>mid) add(tr,query(rc[u],mid+1,r,L,R));
        return tr;
    }
    int merg(int u,int v,int l,int r)
    {
        if(!u||!v) return u+v;
        int p=++t; cle(t);
        if(l==r) w[p]=w[u],k[p]=k[u],add(w[p],w[v]),add(k[p],k[v]);
        else pushdown(u),pushdown(v),lc[p]=merg(lc[u],lc[v],l,mid),rc[p]=merg(rc[u],rc[v],mid+1,r);
        if(l<r) pushup(p);
        return p;
    }
}T;
void dfs(int u,int fat)
{
    mn[u]=a[u]; vector <int> p;
    for(int v:e[u]) if(v!=fat)
    {
        dfs(v,u),p.pb(v);
        mn[u]=min(mn[u],mn[v]);
    }
    sort(p.begin(),p.end(),cmp);
    int s=p.size(),cur=0;
    for(int i=0;i<s;i++)
    {
        x[i]=1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i]%mod*inv[i+1]%mod;
        cur=T.merg(cur,rt[p[i]],1,m);
    }
    cur=0;
    for(int i=s-1;~i;i--)
    {
        add(x[i],1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i+1]%mod*inv[i+2]%mod);
        cur=T.merg(cur,rt[p[i]],1,m),rt[u]=T.merg(rt[u],rt[p[i]],1,m);
    }
    for(int i=1;i<s;i++) if(mn[p[i]]!=mn[p[i-1]]) T.update(rt[u],1,m,mn[p[i-1]],mn[p[i]]-1,inv[i]);
    if(s) T.update(rt[u],1,m,mn[p[s-1]],m,inv[s]);
    for(int i=0;i<s;i++) T.change(rt[u],1,m,mn[p[i]],x[i]);
    T.change(rt[u],1,m,a[u],T.query(rt[u],1,m,a[u]+1,m)+1),T.update(rt[u],1,m,a[u]+1,m,0),add(res,T.w[rt[u]]);
}
void sol()
{
    cin>>n>>rot,res=T.t=0;
    for(int i=1;i<=n;i++) cin>>a[i],b[i]=a[i],rt[i]=0,e[i].clear();
    sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b; 
    for(int i=1,u,v;i<n;i++) cin>>u>>v,e[u].pb(v),e[v].pb(u);
    dfs(rot,0),cout<<res<<'\n';
}
int main()
{
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    inv[0]=inv[1]=1;
    for(int i=2;i<=P;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    int TT; cin>>TT;
    while(TT--) sol();
    return 0;
}