题解:P15649 [省选联考 2026] 找寻者 / recollector
首先可以注意到,期望的和就是和的期望,因此不妨从每个点到祖先上轻边数量的总和考虑。
发现带着轻边数量难以设计优秀的状态,不妨考虑交换贡献形式,可以发现每条轻边仅会在轻边父子中儿子对应的子树产生贡献,因此可以将“每个点到祖先的轻边数量和”变成“所有轻边中儿子的子树大小和”。下面为了讨论方便,将统计目标变成了“所有重边中儿子的子树大小和”,因为一个点只会往下延伸一条重边,统计起来更方便一些。
期望值可以在每个点上统计,只需要枚举这个点的重儿子,算出取到这个重儿子对应的概率,与重儿子子树大小相乘后累加即可。因此问题进一步转化成确认每个儿子作为重儿子的概率。这样就终于可以只在概率而非期望的角度考虑这个问题了。
设
对于
::::info[多项式除法]{open} 不妨先想想,竖式除法是怎么确定结果的每一位的?
假设需要将长度为
结果多项式为
在计算除法的结果后,直接枚举
多项式除法的计算量和上述转移对应是相同的,因此只需要考虑上述转移的总时间复杂度。注意到
最后,考虑和式
这个和式对于每个点
::::info[代码]
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;
using i64 = long long;
const int mod = 998244353;
i64 pw(int x, int y) {
i64 res = 1, bas = x;
while (y) {
if (y & 1)
res = res * bas % mod;
bas = bas * bas % mod;
y >>= 1;
}
return res;
}
int t, n;
vector<int> g[5010];
i64 f[5010][5010], tmp[5010], tmp2[5010];
i64 ans = 0;
int sz[5010];
int mx[5010];
i64 inv[5010];
void dfs1(int x, int fa) {
sz[x] = mx[x] = 1;
for (auto v : g[x]) if (v != fa) {
dfs1(v, x);
sz[x] += sz[v];
mx[x] = max(mx[x], mx[v] + 1);
}
}
void dfs2(int x, int fa) {
if (sz[x] == 1) {
f[x][0] = 0;
f[x][1] = 1;
return;
}
// 第一部分:合并子树多项式
int len = 0;
f[x][0] = 1;
for (auto v : g[x]) if (v != fa) {
dfs2(v, x);
int tl = mx[v];
for (int i = 0; i <= len + tl; i ++)
tmp[i] = 0;
for (int a = 0; a <= len; a ++)
for (int b = 0; b <= tl; b ++)
(tmp[a + b] += f[x][a] * f[v][b]) %= mod;
for (int i = 0; i <= len + tl; i ++)
f[x][i] = tmp[i];
len += tl;
}
// 第二部分:模拟多项式除法,并写入结果
for (int i = 0; i <= mx[x]; i ++)
f[x][i] = 0;
for (auto v : g[x]) if (v != fa) {
i64 expe = 0;
int ptr = 0;
while (f[v][ptr] == 0)
++ ptr;
for (int i = ptr; i <= len; i ++)
tmp2[i] = tmp[i];
int len2 = mx[v];
i64 iv = pw(f[v][ptr], mod - 2);
// tmp2[ptr..len]
// f[v][ptr..len2]
for (int i = ptr; i <= len - (len2 - ptr); i ++) {
// 此时其他点的重链长度为 i - ptr
i64 coef = tmp2[i] * iv % mod;
for (int j = 0; j <= len2 - ptr; j ++) {
// 此时 v 的重链长度为 ptr + j
i64 contri = coef * f[v][ptr + j] % mod;
tmp2[i + j] -= contri;
if (tmp2[i + j] < 0)
tmp2[i + j] += mod;
contri = (ptr + j) * inv[i + j] % mod * contri % mod;
(f[x][ptr + j + 1] += contri) %= mod;
(expe += contri) %= mod;
}
}
(ans += sz[v] * expe) %= mod;
}
}
int main() {
for (int i = 1; i <= 5000; i ++)
inv[i] = pw(i, mod - 2);
int _;
scanf("%d%d", &_, &t);
while (t --) {
scanf("%d", &n);
for (int i = 1; i <= n; i ++)
g[i].clear();
for (int i = 1, a, b; i < n; i ++) {
scanf("%d%d", &a, &b);
g[a].push_back(b);
g[b].push_back(a);
}
ans = 0;
dfs1(1, 1);
dfs2(1, 1);
i64 tot = 0;
for (int i = 2; i <= n; i ++)
(tot += sz[i]) %= mod;
printf("%lld\n", (tot - ans + mod) % mod);
}
return 0;
}
::::