题解:P15649 [省选联考 2026] 找寻者 / recollector

· · 题解

首先可以注意到,期望的和就是和的期望,因此不妨从每个点到祖先上轻边数量的总和考虑。

发现带着轻边数量难以设计优秀的状态,不妨考虑交换贡献形式,可以发现每条轻边仅会在轻边父子中儿子对应的子树产生贡献,因此可以将“每个点到祖先的轻边数量和”变成“所有轻边中儿子的子树大小和”。下面为了讨论方便,将统计目标变成了“所有重边中儿子的子树大小和”,因为一个点只会往下延伸一条重边,统计起来更方便一些。

期望值可以在每个点上统计,只需要枚举这个点的重儿子,算出取到这个重儿子对应的概率,与重儿子子树大小相乘后累加即可。因此问题进一步转化成确认每个儿子作为重儿子的概率。这样就终于可以只在概率而非期望的角度考虑这个问题了。

f_{x, d} 表示对于点 x,其向下的重链长度为 d 对应的概率,其生成函数记为 F_x (u) = \sum_{d=0}^{\infin} f_{x, d} \times u^d。另外记点 x 的所有儿子为 v_1, v_2, \dots, v_k。假设 x 选择 v_p 作为重儿子,那么需要计算其他儿子的重链长度之和为某个值对应的概率,也就是计算

G_{v_p} (u) = \prod_{i \neq p} F_{v_i} (u) = \dfrac {\prod_i F_{v_i} (u)} {F_{v_p} (u)}

对于 \prod_i F_{v_i} (u),可以直接使用 \mathcal{O} (l_1 l_2) 的卷积暴力计算,树型动态规划的性质告诉我们这一部分的时间复杂度是 \mathcal{O} (n^2) 的。而对于多项式除法,考虑如下算法:

::::info[多项式除法]{open} 不妨先想想,竖式除法是怎么确定结果的每一位的?

假设需要将长度为 l_1 的多项式 P(x) 除以长度为 l_2 的多项式 Q(x)。先将 P(x)Q(x) 同时除以 x 的某一个幂次,使得 Q(x) 的常数项非零。随后从低到高考虑除法结果的系数,只需要将当前 P(x) 对应临时结果的最低位乘以 [x^0] Q(x) 的逆元就能得到,随后将系数乘以 Q(x) 从临时结果中减去即可。

结果多项式为 l_1 - l_2 位多项式,而计算其中每一个系数的时间复杂度为 \mathcal{O} (l_2),因此上述算法的时间复杂度就是 \mathcal{O} ((l_1 - l_2) l_2)。 ::::

在计算除法的结果后,直接枚举 v_p 的重链长度 a 和其他儿子的重链长度 b,根据题目给出的概率可以得到贡献:

\begin{align} f_{x, a + 1} &\leftarrow \dfrac{a}{a + b} \times [u^a] F_{v_p} (u) \times [u^b] G_{v_p} (u) \\ ans &\leftarrow \dfrac{a}{a + b} \times [u^a] F_{v_p} (u) \times [u^b] G_{v_p} (u) \times \text{sz}_{v_p} \end{align}

多项式除法的计算量和上述转移对应是相同的,因此只需要考虑上述转移的总时间复杂度。注意到 1 \leq a \leq \text{sz}_{v_p}1 \leq b \leq \text{sz}_x - \text{sz}_{v_p},因此对于每个点 v,其在父亲 \text{fa}_v 处统计贡献时所需的时间复杂度就是 \mathcal{O} ((\text{sz}_{\text{fa}_v} - \text{sz}_v) \text{sz}_v)

最后,考虑和式

\sum_{v=2}^n (\text{sz}_{\text{fa}_v} - \text{sz}_v) \text{sz}_v

这个和式对于每个点 v 等价于将其子树内的点和 v 的兄弟子树,以及 v 的父结点相互匹配,因此每个点对只会在它们的最近公共祖先上计算一次,故上述和式的结果是 \mathcal{O} (n^2) 的。这样,通过上述转移就能在 \mathcal{O} (\sum n^2) 的时间复杂度内解决这个问题。

::::info[代码]

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;

using i64 = long long;
const int mod = 998244353;
i64 pw(int x, int y) {
  i64 res = 1, bas = x;
  while (y) {
    if (y & 1)
      res = res * bas % mod;
    bas = bas * bas % mod;
    y >>= 1;
  }
  return res;
}

int t, n;
vector<int> g[5010];

i64 f[5010][5010], tmp[5010], tmp2[5010];
i64 ans = 0;
int sz[5010];
int mx[5010];
i64 inv[5010];

void dfs1(int x, int fa) {
  sz[x] = mx[x] = 1;
  for (auto v : g[x]) if (v != fa) {
    dfs1(v, x);
    sz[x] += sz[v];
    mx[x] = max(mx[x], mx[v] + 1);
  }
}

void dfs2(int x, int fa) {
  if (sz[x] == 1) {
    f[x][0] = 0;
    f[x][1] = 1;
    return;
  }

  // 第一部分:合并子树多项式
  int len = 0;
  f[x][0] = 1;
  for (auto v : g[x]) if (v != fa) {
    dfs2(v, x);

    int tl = mx[v];
    for (int i = 0; i <= len + tl; i ++)
      tmp[i] = 0;
    for (int a = 0; a <= len; a ++)
      for (int b = 0; b <= tl; b ++)
        (tmp[a + b] += f[x][a] * f[v][b]) %= mod;

    for (int i = 0; i <= len + tl; i ++)
      f[x][i] = tmp[i];
    len += tl;
  }

  // 第二部分:模拟多项式除法,并写入结果
  for (int i = 0; i <= mx[x]; i ++)
    f[x][i] = 0;

  for (auto v : g[x]) if (v != fa) {
    i64 expe = 0;

    int ptr = 0;
    while (f[v][ptr] == 0)
      ++ ptr;
    for (int i = ptr; i <= len; i ++)
      tmp2[i] = tmp[i];
    int len2 = mx[v];
    i64 iv = pw(f[v][ptr], mod - 2);

    // tmp2[ptr..len]
    // f[v][ptr..len2]
    for (int i = ptr; i <= len - (len2 - ptr); i ++) {
      // 此时其他点的重链长度为 i - ptr
      i64 coef = tmp2[i] * iv % mod;
      for (int j = 0; j <= len2 - ptr; j ++) {
        // 此时 v 的重链长度为 ptr + j
        i64 contri = coef * f[v][ptr + j] % mod;
        tmp2[i + j] -= contri;
        if (tmp2[i + j] < 0)
          tmp2[i + j] += mod;

        contri = (ptr + j) * inv[i + j] % mod * contri % mod;
        (f[x][ptr + j + 1] += contri) %= mod;

        (expe += contri) %= mod;
      }
    }

    (ans += sz[v] * expe) %= mod;
  }
}

int main() {
  for (int i = 1; i <= 5000; i ++)
    inv[i] = pw(i, mod - 2);
  int _;
  scanf("%d%d", &_, &t);
  while (t --) {
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++)
      g[i].clear();
    for (int i = 1, a, b; i < n; i ++) {
      scanf("%d%d", &a, &b);
      g[a].push_back(b);
      g[b].push_back(a);
    }

    ans = 0;
    dfs1(1, 1);
    dfs2(1, 1);
    i64 tot = 0;
    for (int i = 2; i <= n; i ++)
      (tot += sz[i]) %= mod;
    printf("%lld\n", (tot - ans + mod) % mod);
  }
  return 0;
}

::::