P15649 [省选联考 2026] 找寻者 / recollector 题解
这题其实并不困难,我用了五分钟就口胡出了做法,但是一个半小时才写好。大致思路就是树上背包,可撤背包。
1
题目要求每个点到根经过的轻边数量。根据期望的线性性,只要求每条边是轻边的概率。不妨用
2
根据题意,肯定自底向上 dp,肯定需要维护关于链长的信息,因此设
大概就是对于一个结点
这就是套路的可撤销背包。看完这个题解可以去做 DFS Order 2 这道题。
::::info[什么是可撤销背包] 背包的过程其实是把很多个序列卷积在一起。我们知道多项式卷积是有交换律和结合律的,这就可以解释为什么做背包的时候,加入物品的顺序不影响结果。
这样一来,我们可以假装每一个物品都是最后一个加入的,这样一来,把刚才加入这个物品的过程倒过来,就可以撤回这个物品。
从多项式的角度来说,可以是把这个序列求逆,再卷积上去。但是在这题显然不能写 ntt,只需要暴力做即可。 ::::
反正就是我们可以先用
3
有人说,这题不是传统的可撤销背包,因为转移的时候乘上了系数
形式化地说,假设把结点
我们可以用小学三年级就学过的竖式乘法知识。我们假设三个序列最低位都是非零的。
我们发现整个过程只需要求
刚才我们假设了三个多项式最低位都非零,实际问题中最低位一定是零,这也没关系。我们统一移动若干位就可以了。因此额外使用
4
有了这些信息,我们就可以很轻松地算出
::::info[核心伪代码]
let now = 0;
// 先算把所有儿子全算上的背包
for v in son of u:
for i = [now, 0](-1):
// 偷懒的背包写法
let x = f[i]
f[i] = 0
for j = [0, mx[v]] f[i + j] += x * dp[v][j]
now += mx[v]
for v in son of u:
// 计算 H 最低非零位置的逆元,用于背包撤回
Z invx = dp[v][mn[v]].inv()
// 把儿子 v 撤掉,结果存到 g 数组
for i = [0, now - mx[v]]:
g[i] = f[i + mn[v]] * invx
for j = [mn[v], mx[v]]: f[i + j] -= g[i] * dp[v][j]
// 计算边 p_v 和 dp_u
for i = [0, now - mx[v]]:
g[i] = f[i + mn[v]] * invx
for j = [mn[v], mx[v]]:
// 除了结点 v 以外 u 的其他儿子的链长之和为 i,结点 v 的链长为 j 时的概率贡献
let x = g[i] * dp[v][j] * inv[i + j] * j
dp[u][j + 1] += x
p[v] += x
// 把儿子 v 加回来
for i = [0, now - mx[v]]:
for j = [mn[v], mx[v]]: f[i + j] -= g[i] * dp[v][j]
mn[u] = find leftmost non-zero position in dp[u]
mx[u] = find rightmost non-zero position in dp[u]
::::
5
::::info[说一嘴复杂度]
可能有人还不知道,为啥每个节点都算了这么一大坨,还是总共
这是常见结论。在结点
我们从实际含义来理解一下。就是
6
::::info[赛后复刻的完整代码]
#include <bits/stdc++.h>
using namespace std;
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef __int128 lll;
typedef __float128 lld;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;
const int MOD = 998244353;
#define len(x) ((int)x.size())
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
template <typename T> T qpow(T a, ll b = -1){
T res = 1;
while (b){
if (b & 1) res *= a;
a *= a, b >>= 1;
}
return res;
}
#pragma region modint
template <const int mod> int norm(int x){ return (x + (x < 0) * mod - (x >= mod) * mod); }
template <const int mod> struct modint{
int x = 0;
int val(){ return x; }
modint() = default;
modint(ll _x) : x(norm <mod> (int(_x % mod))){};
modint& operator += (modint v){ x = norm <mod> (x + v.x); return *this; }
friend modint operator - (modint u){ u.x = norm <mod> (-u.x); return u; }
modint& operator -= (modint v){ x = norm <mod> (x - v.x); return *this; }
modint& operator *= (modint v){ x = int(ll(x) * v.x % mod); return *this; }
modint inv(){ return qpow(*this, mod - 2); }
modint& operator /= (modint v){ return *this *= v.inv(); }
friend modint operator + (modint u, modint v){ return u += v; }
friend modint operator - (modint u, modint v){ return u -= v; }
friend modint operator * (modint u, modint v){ return u *= v; }
friend modint operator / (modint u, modint v){ return u /= v; }
friend istream& operator >> (istream& is, modint &u){ ll x; is >> x; u = modint <mod> (x); return is; }
friend ostream& operator << (ostream& os, modint u){ return os << u.x; }
};
using Z = modint <MOD>;
#pragma endregion
int n;
vector <int> e[5005];
Z inv[5005];
Z dp[5005][5005];
Z f[5005], g[5005];
Z p[5005];
int sz[5005];
int mn[5005], mx[5005];
Z sum[5005], ans;
void dfs(int u, int par){
sz[u] = 1;
if (len(e[u]) - (par != -1) == 0){
dp[u][1] = 1;
mn[u] = mx[u] = 1;
return ;
}
for (auto v: e[u]) if (v != par){
dfs(v, u);
sz[u] += sz[v];
}
for (int i = 0; i <= sz[u]; i++) f[i] = 0, dp[u][i] = 0;
f[0] = 1;
int now = 0;
for (auto v: e[u]) if (v != par){
for (int i = now; i >= 0; i--){
Z x = f[i]; f[i] = 0;
for (int j = 0; j <= mx[v]; j++) f[i + j] += x * dp[v][j];
}
now += mx[v];
}
for (auto v: e[u]) if (v != par){
Z invx = dp[v][mn[v]].inv();
p[v] = 0;
for (int i = 0; i <= now - mx[v]; i++){
g[i] = f[i + mn[v]] * invx;
for (int j = mn[v]; j <= mx[v]; j++){
Z x = g[i] * dp[v][j];
f[i + j] -= x;
Z y = x * inv[i + j] * j;
dp[u][j + 1] += y;
p[v] += y;
}
}
for (int i = 0; i <= now - mx[v]; i++){
for (int j = mn[v]; j <= mx[v]; j++) f[i + j] += g[i] * dp[v][j];
}
}
mn[u] = 1e9, mx[u] = -1e9;
for (int i = 0; i <= sz[u]; i++){
if (dp[u][i].val() != 0){
mn[u] = i;
break;
}
}
for (int i = sz[u]; i >= 0; i--){
if (dp[u][i].val() != 0){
mx[u] = i;
break;
}
}
}
void push(int u, int par){
for (auto v: e[u]) if (v != par){
sum[v] = sum[u] + 1 - p[v];
ans += sum[v];
push(v, u);
}
}
void solve(int tid){
int n; cin >> n;
for (int i = 1; i <= n; i++) e[i].clear();
for (int i = 0; i < n - 1; i++){
int u, v; cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, -1);
ans = 0;
push(1, -1);
cout << ans << "\n";
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
for (int i = 1; i <= 5000; i++) inv[i] = Z(i).inv();
int c, t; cin >> c >> t;
for (int i = 0; i < t; i++) solve(i);
return 0;
}
::::