P15649 [省选联考 2026] 找寻者 / recollector 题解

· · 题解

这题其实并不困难,我用了五分钟就口胡出了做法,但是一个半小时才写好。大致思路就是树上背包,可撤背包。

1

题目要求每个点到根经过的轻边数量。根据期望的线性性,只要求每条边是轻边的概率。不妨用 p_i 表示点 i 到父亲的边是重边的概率。则答案就是 \sum_u \sum_{v|v \text{ is an ancestor of } u} p_v 。这部分随便做,所以重点就是求 p_i

2

根据题意,肯定自底向上 dp,肯定需要维护关于链长的信息,因此设 dp_{i, j} 表示结点 i 的重链向下延伸长度为 j 的概率。如此一来,对于边 (u, v),这里 uv 父亲,它的答案就是 p_v = \sum_{\{l_i\}} \prod_{w \text{ is a son of }u} dp_{w,l_w}\times {l_v \over {\sum l_w}}

大概就是对于一个结点 u,你需要枚举每一个儿子 v,把剩下儿子 w\{dp_{w, i}\} 序列背包在一起(这里是名词活用作动词),算出对每个 x 算出 g_x 表示结点 u 除了 v 以外的儿子中选总长为 x 的链的概率之和,再枚举结点 v 的链长,乘一乘。

这就是套路的可撤销背包。看完这个题解可以去做 DFS Order 2 这道题。

::::info[什么是可撤销背包] 背包的过程其实是把很多个序列卷积在一起。我们知道多项式卷积是有交换律和结合律的,这就可以解释为什么做背包的时候,加入物品的顺序不影响结果

这样一来,我们可以假装每一个物品都是最后一个加入的,这样一来,把刚才加入这个物品的过程倒过来,就可以撤回这个物品。

从多项式的角度来说,可以是把这个序列求逆,再卷积上去。但是在这题显然不能写 ntt,只需要暴力做即可。 ::::

反正就是我们可以先用 u 的全部儿子算背包,再挨个把儿子 v 撤掉,计算,再加回去。

3

有人说,这题不是传统的可撤销背包,因为转移的时候乘上了系数 dp_{w, l_w}。其实是不影响的。我们从低位到高位考虑。

形式化地说,假设把结点 u 的所有儿子的 dp_{u, j} 卷积在一起得到的序列是 F,儿子 v\{dp_{v_i}\} 序列记为 H。我们要求 G,使得 G \cdot H = F

我们可以用小学三年级就学过的竖式乘法知识。我们假设三个序列最低位都是非零的。F 的最低位一定只来自于 GH 的最低位。因此我们直接用 FH 的最低位算出 G 的最低位,也就是 G_0 = {F_0 \over H_0}。接下来,我们从 F 中减去 G_0H,然后 F 的最低非零位就变成了 F_1,还是 G_1 = {F_1\over H_0},然后从 F 当中减去 G_1H

我们发现整个过程只需要求 H_0 的逆元就可以了。而且复杂度的建模还是类似于树上背包,复杂度是 O(n^2 + n \log n) = O(n^2)

刚才我们假设了三个多项式最低位都非零,实际问题中最低位一定是零,这也没关系。我们统一移动若干位就可以了。因此额外使用 mn_vmx_v 存结点 v\{dp_{v, i}\} 序列最低和最高的非零位。根据含义,这个序列的和一定是 1,因此一定能找到非零位置,也不存在能够卡到全部数值恰好都是 998244353 的倍数的数据卡爆逆元。

4

有了这些信息,我们就可以很轻松地算出 dp_{u, i}p_v,我直接放伪代码你们就能看懂。

::::info[核心伪代码]

let now = 0;
// 先算把所有儿子全算上的背包
for v in son of u:
    for i = [now, 0](-1):
        // 偷懒的背包写法
        let x = f[i]
        f[i] = 0
        for j = [0, mx[v]] f[i + j] += x * dp[v][j]
    now += mx[v]

for v in son of u:
    // 计算 H 最低非零位置的逆元,用于背包撤回
    Z invx = dp[v][mn[v]].inv()
    // 把儿子 v 撤掉,结果存到 g 数组
    for i = [0, now - mx[v]]:
        g[i] = f[i + mn[v]] * invx
        for j = [mn[v], mx[v]]: f[i + j] -= g[i] * dp[v][j]
    // 计算边 p_v 和 dp_u
    for i = [0, now - mx[v]]:
        g[i] = f[i + mn[v]] * invx
        for j = [mn[v], mx[v]]:
            // 除了结点 v 以外 u 的其他儿子的链长之和为 i,结点 v 的链长为 j 时的概率贡献
            let x = g[i] * dp[v][j] * inv[i + j] * j
            dp[u][j + 1] += x
            p[v] += x
    // 把儿子 v 加回来
    for i = [0, now - mx[v]]:
        for j = [mn[v], mx[v]]: f[i + j] -= g[i] * dp[v][j]

mn[u] = find leftmost non-zero position in dp[u]
mx[u] = find rightmost non-zero position in dp[u]

::::

5

::::info[说一嘴复杂度] 可能有人还不知道,为啥每个节点都算了这么一大坨,还是总共 O(n^2) 的。

这是常见结论。在结点 u 处,依次加入每个儿子 v,复杂度计算量是 O(sz_v \cdot \sum_{w \text{ is added before } v} sz_w) 或者 O(sz_v \cdot (sz_u - sz_v))。这两个东西是一样的,因为前面的东西正反跑两边就是后面的东西。

我们从实际含义来理解一下。就是 u 的子树中,不在同一子树中的点对被算了 O(1) 次,在同一子树内的点对不被计算。反过来想,整棵树上的所有点对,都只在他们的 lca 处被算过 O(1) 次,这样整个的复杂度也是 O(n^2) 的了。 ::::

6

::::info[赛后复刻的完整代码]

#include <bits/stdc++.h>
using namespace std;
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef __int128 lll;
typedef __float128 lld;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;
const int MOD = 998244353;
#define len(x) ((int)x.size())
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()

template <typename T> T qpow(T a, ll b = -1){
    T res = 1;
    while (b){
        if (b & 1) res *= a;
        a *= a, b >>= 1;
    }
    return res;
}
#pragma region modint
template <const int mod> int norm(int x){ return (x + (x < 0) * mod - (x >= mod) * mod); }
template <const int mod> struct modint{
    int x = 0;
    int val(){ return x; }
    modint() = default;
    modint(ll _x) : x(norm <mod> (int(_x % mod))){};
    modint& operator += (modint v){ x = norm <mod> (x + v.x); return *this; }
    friend modint operator - (modint u){ u.x = norm <mod> (-u.x); return u; }
    modint& operator -= (modint v){ x = norm <mod> (x - v.x); return *this; }
    modint& operator *= (modint v){ x = int(ll(x) * v.x % mod); return *this; }
    modint inv(){ return qpow(*this, mod - 2); }
    modint& operator /= (modint v){ return *this *= v.inv(); }
    friend modint operator + (modint u, modint v){ return u += v; }
    friend modint operator - (modint u, modint v){ return u -= v; }
    friend modint operator * (modint u, modint v){ return u *= v; }
    friend modint operator / (modint u, modint v){ return u /= v; }
    friend istream& operator >> (istream& is, modint &u){ ll x; is >> x; u = modint <mod> (x); return is; }
    friend ostream& operator << (ostream& os, modint u){ return os << u.x; }
};
using Z = modint <MOD>;
#pragma endregion

int n;
vector <int> e[5005];
Z inv[5005];
Z dp[5005][5005];
Z f[5005], g[5005];
Z p[5005];
int sz[5005];
int mn[5005], mx[5005];
Z sum[5005], ans;

void dfs(int u, int par){
    sz[u] = 1;
    if (len(e[u]) - (par != -1) == 0){
        dp[u][1] = 1;
        mn[u] = mx[u] = 1;
        return ;
    }
    for (auto v: e[u]) if (v != par){
        dfs(v, u);
        sz[u] += sz[v];
    }

    for (int i = 0; i <= sz[u]; i++) f[i] = 0, dp[u][i] = 0;
    f[0] = 1;

    int now = 0;
    for (auto v: e[u]) if (v != par){
        for (int i = now; i >= 0; i--){
            Z x = f[i]; f[i] = 0;
            for (int j = 0; j <= mx[v]; j++) f[i + j] += x * dp[v][j];
        }
        now += mx[v];
    }

    for (auto v: e[u]) if (v != par){
        Z invx = dp[v][mn[v]].inv();
        p[v] = 0;
        for (int i = 0; i <= now - mx[v]; i++){
            g[i] = f[i + mn[v]] * invx;
            for (int j = mn[v]; j <= mx[v]; j++){
                Z x = g[i] * dp[v][j];
                f[i + j] -= x;
                Z y = x * inv[i + j] * j;
                dp[u][j + 1] += y;
                p[v] += y;
            }
        }
        for (int i = 0; i <= now - mx[v]; i++){
            for (int j = mn[v]; j <= mx[v]; j++) f[i + j] += g[i] * dp[v][j];
        }
    }
    mn[u] = 1e9, mx[u] = -1e9;
    for (int i = 0; i <= sz[u]; i++){
        if (dp[u][i].val() != 0){
            mn[u] = i;
            break;
        }
    }
    for (int i = sz[u]; i >= 0; i--){
        if (dp[u][i].val() != 0){
            mx[u] = i;
            break;
        }
    }
}
void push(int u, int par){
    for (auto v: e[u]) if (v != par){
        sum[v] = sum[u] + 1 - p[v];
        ans += sum[v];
        push(v, u);
    }
}

void solve(int tid){
    int n; cin >> n;
    for (int i = 1; i <= n; i++) e[i].clear();
    for (int i = 0; i < n - 1; i++){
        int u, v; cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1, -1);
    ans = 0;
    push(1, -1);
    cout << ans << "\n";
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    for (int i = 1; i <= 5000; i++) inv[i] = Z(i).inv();

    int c, t; cin >> c >> t;
    for (int i = 0; i < t; i++) solve(i);
    return 0;
}

::::