题解:P15578 [USACO26FEB] Random Tree Generation G

· · 题解

树形 DP + 快速幂求逆元

本题解中将生成随机带标号树的两步中第一步产生的树称为原树,将输入的树称为新树。

首先考虑在新树中任意指定一个点 u 为根,即钦定节点 u 为原树中的节点 1,然后考虑给所有新树中的 i 号点分配分配一个原树中的节点编号 q_i。那么 q 序列要满足下面两条性质才是正确的:

于是我们考虑计算任意一个 q 序列满足条件的概率。对于每个点 uq 序列满足性质 2 的概率应该是 \frac{1}{siz_u},且对于所有点都要满足这一条件,因此整体的概率为 \prod_{u} \frac{1}{siz_u},再乘上 1n 的排列总数 n!,以及累加指定的根不同时的结果,可得满足条件的序列 q 数量为 n! \sum_{rt}{{\prod_{u} \frac{1}{siz_{rt,u}}}}。其中 siz{rt,u} 表示指定新树的根为 rtu 子树的大小。

因为第一步生成的树的总方案数显然为 (n-1)!,因此第一步生成符合条件的原树的概率就是 \frac{\sum_{rt}{n! {\prod_{u} \frac{1}{siz_{rt,u}}}}}{(n-1)!}

第二步是简单的,只需要保证重新分配的编号与新树编号一致即可,显然概率为 \frac{1}{n!}

于是最终生成新树 T 的概率为:P(T)=\frac{n! \sum_{rt}{{\prod_{u} \frac{1}{siz_{rt,u}}}}}{(n-1)!n!}=\frac{ \sum_{rt}{{\prod_{u} \frac{1}{siz_{rt,u}}}}}{(n-1)!}

注意到分母部分是容易计算的,而分子上连乘的一部分并不能简单地计算。

于是考虑用树形 DP 维护,设 f_j=\prod_{u} {siz_{j,u}},最初以 1 为根,先用一次 DFS 计算出 f_1

如下图所示,当前 DP 到 u,下一步要 DP v 点时,黄色点子树的大小是不变的,只有 u,v 两个点的子树大小发生了变化,u 点的 sizn 变成了 n-siz_{1,v}v 点的 sizsiz_{1,v} 变成了 n

于是就可以顺利进行 DP 了,对于每个 u 的子节点 vf_v=f_u \times \frac{n-siz{1,v}}{n} \times \frac{n}{siz_{1,v}}=f_u \times \frac{n-siz_{1,v}}{siz{1,v}},用快速幂求 siz_{1,v} 的逆元即可。

代码:

#include <bits/stdc++.h>

#define int long long

using namespace std;

const int N=2e5+10,mod=1e9+7;
int n,f[N],siz[N];
vector<int> g[N];

int fpow(int a,int b)
{
    int res=1;
    a%=mod;
    while(b)
    {
        if(b&1) res=(res*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return res;
}

int inv(int x)
{
    return fpow(x,mod-2);
}

void dfs(int u,int fa)
{
    siz[u]=1;
    for(int v:g[u])
    {
        if(v==fa) continue;
        dfs(v,u);
        siz[u]+=siz[v];
    }
}

void dp(int u,int fa)
{
    for(int v:g[u])
    {
        if(v==fa) continue;
        f[v]=f[u]*(n-siz[v])%mod*inv(siz[v])%mod;
        dp(v,u);
    }
}

void solve()
{
    cin>>n;
    for(int i=1;i<=n;i++) g[i].clear();
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    memset(siz,0,sizeof siz);
    memset(f,0,sizeof f);
    dfs(1,0);
    f[1]=1;
    for(int i=1;i<=n;i++) f[1]=(f[1]*siz[i])%mod;
    dp(1,0);
    int ans=0,fac=1;
    for(int i=1;i<=n-1;i++) fac=(fac*i)%mod;
    ans=inv(fac);
    int res=0;
    for(int i=1;i<=n;i++) res=(res+inv(f[i]))%mod;
    ans=(ans*res)%mod;
    cout<<ans<<"\n";
}
signed main()
{
    int T;
    cin>>T;
    while(T--) solve();
}