题解:P15578 [USACO26FEB] Random Tree Generation G

· · 题解

闲话

第一次打金组就成功晋级!场切蓝题特此纪念。

非常好计数题。

题目大意

给定一棵 N 个节点的树,它由以下两步随机过程生成:

  1. 从节点 1 开始,对 i=2,\dots,N,将 i\mathrm{randint}(1,i-1) 连边(每个 i 独立均匀选择父节点)。这样生成一棵父节点编号小于子节点的树,共有 (N-1)! 种等可能结果。
  2. 随机均匀打乱所有节点的编号(即随机排列 p_1,\dots,p_N,将原编号 v 改为 p_v),共有 N! 种等可能结果。

现给定最终树的边集,求生成这棵树的概率,对 10^9+7 取模。

解题思路

T 为给定树。考虑第一步生成的树 S 和第二步的排列 \pi,满足 \pi(S)=T,即 S=\pi^{-1}(T)S 满足题意当且仅当存在一个根 r,使得整棵树满足每个节点标号大于其父节点。

对于以固定节点 r 为根的有根树,考虑随机排列所有 N 个节点的标号。对于任意节点 v,它在其子树中是编号最小的概率为 \frac{1}{sz_r(v)},其中 sz_r(v) 是该子树的大小。这些事件彼此独立,因此所有节点同时满足该条件的概率为

\prod_{v=1}^N \frac{1}{sz_r(v)}

而整棵树满足性质等价于每个节点都是其子树中最小的,故符合条件的排列数为

N! \times \prod_{v=1}^N \frac{1}{sz_r(v)} = \frac{N!}{\prod_{v} sz_r(v)}

其中 sz_r(v) 是以 r 为根时节点 v 的子树大小。因此,所有可能的 (\pi,r) 总数为

\sum_{r=1}^N \frac{N!}{\prod_v sz_r(v)}

故生成树 T 的概率为

=\frac{1}{(N-1)!}\sum_{r=1}^N\frac{1}{\prod_v sz_r(v)}

dp[u]=\prod_v sz_u(v),则答案即为

\frac{1}{(N-1)!}\sum_{u=1}^N dp[u]^{-1}\pmod{10^9+7}

考虑换根 DP 求所有 dp[u]。首先以 1 为根进行 DFS,得到每个节点的子树大小 sz[u](以 1 为根)。此时

dp[1]=\prod_{v=1}^N sz[v]

考虑从父亲 fa 转移到儿子 u。当根从 fa 变为 u 时,只有节点 fau 的子树大小发生变化,其他节点不变:

因此:

dp[u]=dp[fa]\times\frac{N}{sz[u]}\times\frac{N-sz[u]}{N}=dp[fa]\times\frac{N-sz[u]}{sz[u]}

通过第二次 DFS 即可 O(n) 求出所有 dp[u]

求出所有 dp[u] 后,计算

S=\sum_{u=1}^N dp[u]^{-1}\bmod M,\quad \text{ans}=S\times ((N-1)!)^{-1}\bmod M

其中 M=10^9+7 是质数,用快速幂求逆元。

复杂度分析

代码实现

#include <bits/stdc++.h>
using namespace std;

#define ll long long

const ll N = 1e6 + 5, mod = 1e9 + 7;
ll tc, n, fac[N], invf[N], inv[N];
ll sz[N], fa[N], dp[N];
vector<ll> G[N];

ll qpow(ll x, ll y) {
    if (y == 0) { return 1; }
    ll res = qpow(x, y / 2);
    res = res * res % mod;
    if (y % 2) {
        res = res * x % mod;
    }
    return res;
}

void init() {
    for (ll i = 1; i <= n; i++) {
        G[i].clear();
    }
    fac[0] = 1;
    for (ll i = 1; i <= n; i++) {
        fac[i] = fac[i - 1] * i % mod;
    }
}

void dfs1(ll u, ll ff = 0) {
    fa[u] = ff;
    sz[u] = 1;
    for (ll v : G[u]) {
        if (v == ff) { continue; }
        dfs1(v, u);
        sz[u] += sz[v];
    }
}

void dfs2(ll u, ll ff = 0) {
    if (u != 1) {
        dp[u] = dp[ff] * qpow(sz[u], mod - 2) % mod * (n - sz[u]) % mod;
    }
    for (ll v : G[u]) {
        if (v == ff) { continue; }
        dfs2(v, u);
    }
}

void solve() {
    cin >> n;
    init();
    for (ll i = 1; i < n; i++) {
        ll u, v;
        cin >> u >> v;
        G[u].push_back(v), G[v].push_back(u);
    }
    dfs1(1ll);
    dp[1] = 1;
    for (ll i = 1; i <= n; i++) {
        (dp[1] *= sz[i]) %= mod;
    }
    dfs2(1ll);
    ll sum = 0;
    for (ll i = 1; i <= n; i++) {
        (sum += qpow(dp[i], mod - 2)) %= mod;
    }
    cout << sum * qpow(fac[n - 1], mod - 2) % mod << "\n";
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> tc;
    while (tc--) { solve(); }
    return 0;
}