题解:AT_agc073_c [AGC073C] Product of Max of Sum of Subtree

· · 题解

题目大意

有一棵 n 个点的树,考虑给每个点均匀随机一个在 [-(n-1),1] 之间的实数。

对于每个点 x,定义 p_x 为树上包含点 x 的联通块的最大权值和。

定义这棵树的贡献为,若 \forall i,0\leqslant p_x\leqslant1,则贡献为 \prod_{i=1}^np_i,否则贡献为 0

求贡献的期望。

解法

为了区分取值范围与点数,设取值范围为 len,易知 len=n

先考虑一个弱化问题:若整棵树的 p 值相等,求贡献的期望。

不妨设整棵树的 p 值为 f,易知 f\in[0,1] 时贡献非零。

s_x 表示 x 的子树中所有随机权值的和,由于若 s_x<0,则抛弃掉这颗子树显然更优,而又由于 s_x\leqslant f,所以 s_x 的取值范围为 s_x\in[0,f]

因为一种随机方案与一种 s_x 的方案唯一对应,由于要求 s_1=f,所以序列 s 合法的概率为 f^{n-1}\over len^n,所以树的期望贡献为:

\int_0^1{x^{n-1}\over len^n}x^n

简单化简一下:

\begin{aligned} &\int_0^1{x^{n-1}\over len^n}x^n\\ =&{1\over len^n}\int_0^1x^{n-1}x^n\\ =&{1\over len^n}\int_0^1x^{2n-1}\\ =&{1\over len^n}\left({1^{2n}\over2n}-{0^{2n}\over 2n}\right)\\ =&{1\over len^n}\times{1\over2n} \end{aligned}

再考虑原问题。

容易发现,我们可以把树分为许多值相同的联通块,对于一个联通块可以用上述方法去做,且可以做到联通块之间互不影响。

若两点 xy 属于不同联通块,不妨设 x 所在联通块的 f_x 大于 y 所在联通块的 f_y,那么我们给 y 随机的权值减去 f_x 就行了,这样联通块就能互不重合了,f_x\leqslant f_y 同理。因为最多减掉 n-1,所以不会超出随机值下界。

然后就可以在树上 dp 了。

f_{x,i} 表示 x 的子树中 x 所在的联通块大小,i=0 表示该联通块已计算完贡献。

转移是平凡的,由于所有联通块大小的和一定为 n,于是 dp 完后再乘上 1\over len^n 就行了。

时间复杂度 \mathcal{O}(n^2)

::::info[Code]

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e3+10;
const int p=998244353;
int n,m,inv[100010];
int f[N][N],sz[N],h[N];
vector<int> g[N];
int mo(int x){return x>=p?x-p:x;}
void add(int&x,int y) {x=mo(x+y);}
int ksm(int a,int b){
    int ans=1;
    for(;b;b>>=1,a=a*a%p) if(b&1) ans=ans*a%p;
    return ans;
}
void dfs(int x,int fa){
    sz[x]=1;
    f[x][1]=1;
    for(int v:g[x]){
        if(v==fa) continue;
        dfs(v,x);
        for(int i=1;i<=sz[x];i++)
            for(int j=0;j<=sz[v];j++)
                add(h[i+j],f[x][i]*f[v][j]%p);
        sz[x]+=sz[v];
        for(int i=1;i<=sz[x];i++) f[x][i]=h[i],h[i]=0;
    }
    for(int i=1;i<=sz[x];i++) add(f[x][0],f[x][i]*inv[i*2]%p);
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n;
    inv[1]=1;
    for(int i=2;i<=n*2;i++) inv[i]=(p-p/i)*inv[p%i]%p;
    for(int i=1,x,y;i<n;i++) cin>>x>>y,g[x].push_back(y),g[y].push_back(x);
    dfs(1,0);
    cout<<f[1][0]*ksm(ksm(n,n),p-2)%p<<"\n";
    return 0;
}

::::