题解 P9090 【「SvR-2」G64】
Leasier
·
·
题解
来点对注意力要求不高的做法。
Subtask 1 \sim 4
首先可以直接 dp 求出每个子树的最大独立集。
接下来不难发现从 G_{x - 1}(T) 到 G_x(T) 的过程中我们只关心根和右链端点选不选,于是设 dp_{x, i = 0/1, j = 0/1} 表示 G_x(T) 中根选的情况为 i,右链端点选的情况为 j,此时的最大独立集为多少。
暴力实现即可。时间复杂度为 $O(n + qx)$。
### 无特殊限制
类似于 [[SDOI / SXOI2022] 小 N 的独立集](https://www.luogu.com.cn/problem/P8352),我们有结论:
- 当二叉树 $T$ 的右链长度 $> 2$,$f(T, i, j)$ 的极差 $\leq 2$,这里 $f$ 的定义基本同 $dp_{x, i, j}$。
证明:
- 考虑一个独立集去掉一个点还是一个独立集,有 $T(x, 0, i) \geq T(x, 1, i) - 1, T(x, i, 0) \geq T(x, i, 1) - 1$。
- 考虑强行将根或右链端点加入独立集并删去周围点,有 $T(x, 1, i) \geq T(x, 0, i) - 1, T(x, i, 1) \geq T(x, i, 0) - 1$。
- 综合上式,有 $|T(x, 0, i) - T(x, 1, i)| \leq 1, |T(x, i, 0) - T(x, i, 1)| \leq 1$。
- 于是 $|T(x, 0, 0) - T(x, 1, 1)| \leq 2, |T(x, 0, 1) - T(x, 1, 0)| \leq 2$,遂得证。
于是我们可以这样表出一个 $x$ 对应的状态:
- $(base, v_{0/1, 0/1})$,表示 $dp_{x, i, j} = v_{i, j} + base$。
- 其中 $\min(v_{i, j}) = 0, \max(v_{i, j}) \leq 2$。
可见后面 $v_{0/1, 0/1}$ 的状态数 $\leq 81$(实际上打表可知只有 $30$ 个)。
每次转移 $dp_{x - 1, i, j} \to dp_{x, i, j}$ 时,我们可以预处理出其从 $v_{0/1, 0/1}$ 转移到了 $v'_{0/1, 0/1}$,且 $base \leftarrow 4base + \Delta$。
在此基础上倍增即可。时间复杂度为 $O(n + q \log x)$。
我的实现中直接把表放在预处理函数里了(下示代码中略去)。
代码:
```cpp
#include <stdio.h>
typedef long long ll;
const int N = 66, M = 29, K = 1e6 + 7, P = 1 + 1, mod = 998244353;
int to[N + 7][M + 7], delta[N + 7][M + 7], power[M + 7], ls[K], rs[K], dp[K][P][P], g1[K][P][P], temp[P][P];
inline int add(int x, int y){
return x + y >= mod ? x + y - mod : x + y;
}
inline void init(){
// 打表结果放这里
power[0] = 4;
for (register int i = 1; i <= M; i++){
for (register int j = 0; j <= N; j++){
to[j][i] = to[to[j][i - 1]][i - 1];
delta[j][i] = add(1ll * delta[j][i - 1] * power[i - 1] % mod, delta[to[j][i - 1]][i - 1]);
}
power[i] = 1ll * power[i - 1] * power[i - 1] % mod;
}
}
inline int read(){
int ans = 0;
char ch = getchar();
while (ch < '0' || ch > '9'){
ch = getchar();
}
while (ch >= '0' && ch <= '9'){
ans = ans * 10 + (ch ^ 48);
ch = getchar();
}
return ans;
}
inline int max(int a, int b){
return a > b ? a : b;
}
void dfs(int u){
if (ls[u] != 0) dfs(ls[u]);
if (rs[u] != 0) dfs(rs[u]);
int p = max(dp[ls[u]][0][0], dp[ls[u]][0][1]), q = max(p, max(dp[ls[u]][1][0], dp[ls[u]][1][1]));
if (rs[u] == 0){
dp[u][0][0] = max(p, max(dp[ls[u]][1][0], dp[ls[u]][1][1]));
dp[u][0][1] = dp[u][1][0] = 0x80000000;
dp[u][1][1] = p + 1;
} else {
dp[u][0][0] = q + max(dp[rs[u]][0][0], dp[rs[u]][1][0]);
dp[u][0][1] = q + max(dp[rs[u]][0][1], dp[rs[u]][1][1]);
dp[u][1][0] = p + dp[rs[u]][0][0] + 1;
dp[u][1][1] = p + dp[rs[u]][0][1] + 1;
}
for (register int i = 0; i <= 1; i++){
for (register int j = 0; j <= 1; j++){
g1[u][i][j] = max(dp[u][i][0] + max(dp[u][0][j], dp[u][1][j]), dp[u][i][1] + dp[u][0][j]);
}
}
}
inline int min(int a, int b){
return a < b ? a : b;
}
inline void trans(int f[P][P], int &base, int &state){
for (register int i = 0; i <= 1; i++){
for (register int j = 0; j <= 1; j++){
temp[i][j] = 0x80000000;
}
}
for (register int i = 0; i <= 1; i++){
for (register int j = 0; j <= 1; j++){
for (register int k = 0; i + k <= 1; k++){
for (register int l = 0; l <= 1; l++){
for (register int x = 0; i + x <= 1; x++){
for (register int y = 0; y <= 1; y++){
for (register int z = 0; y + z <= 1; z++){
for (register int w = 0; z + w <= 1; w++){
for (register int p = 0; p <= 1; p++){
for (register int q = 0; z + q <= 1; q++){
temp[i][j] = max(temp[i][j], f[k][l] + f[x][y] + f[w][p] + f[q][j] + i + z);
}
}
}
}
}
}
}
}
}
}
base = min(temp[0][0], min(temp[0][1], min(temp[1][0], temp[1][1])));
state = (temp[0][0] - base) + (temp[0][1] - base) * 3 + (temp[1][0] - base) * 9 + (temp[1][1] - base) * 27;
}
int main(){
int n = read(), q = read();
init();
for (register int i = 1; i <= n; i++){
ls[i] = read();
rs[i] = read();
}
dfs(1);
for (register int I = 1; I <= q; I++){
int x = read(), i = read();
if (--x == 0){
printf("%d\n", max(g1[i][0][0], max(g1[i][0][1], max(g1[i][1][0], g1[i][1][1]))));
continue;
}
int base, state;
trans(g1[i], base, state);
x--;
for (register int j = 0; (1 << j) <= x; j++){
if (x >> j & 1){
base = add(1ll * base * power[j] % mod, delta[state][j]);
state = to[state][j];
}
}
printf("%d\n", max(state % 3, max(state / 3 % 3, max(state / 9 % 3, state / 27))) + base);
}
return 0;
}
```