CF1868C 题解
vegetable_king
·
·
题解
很好写的 O(\sum m \log n) 做法,目前 solution size 排第三。
因为 m 比较小,所以考虑枚举 \max。容斥一下,设 F_x 表示对于每条链,链上的点 \max \le x,其他点任意填的方案数,即 \max 恰好等于 i 的方案数即为 ans_x = F_x - F_{x - 1}。答案为 \sum_{x = 1}^m x(F_x - F_{x - 1}) = F_m - \sum_{x = 1}^{m - 1} F_x。
考虑树形 DP,设 $f_k$ 表示大小为 $k$ 的树内的路径填法和,$g_k$ 表示到大小为 $k$ 树的内**到根**的路径填法和,有转移方程:
$$g_k = x(g_{ls}m^{rs} + m^{ls}g_{rs} + m^{k - 1})$$
$$f_k = g_k + xg_{ls}g_{rs} + f_{ls}m^{rs + 1} + m^{ls + 1}f_{rs}$$
其中 $ls, rs$ 分别为左右子树大小。状态数看似是 $O(n)$ 的,但是实际上左右两边至少一边是 $2^y - 1$ 的形式,所以对 $f_n$ 有用的状态数实际上是 $O(\log n)$ 的。
转移时如果用快速幂求出系数,就会多出一个 $\log$,但是我们对用到的状态都预处理出 $m^k$ 就可以 $O(\log n)$ 求出 $f_n$。
我们对于每一个 $x$,都 $O(\log n)$ 求出 $F_x$(即当前 $x$ 求出的 $f_n$),就可以 $O(m \log n)$ 求出原本的答案了。
实现时不需要开一个 `unordered_map` 去记忆化搜索,只需要记录所有 $k = 2^y - 1$ 的状态,时间复杂度就是对的。
代码真的很好写:
```cpp
#include <unordered_map>
#include <algorithm>
#include <cstring>
#include <cstdio>
#define mod 998244353
#define popc __builtin_popcountll
#define log2 __builtin_ctzll
using namespace std;
const int N = 100001;
typedef long long ll;
int t, n, m, ans[N], sum;ll k;
struct node{int vf, vg, vp;}F[64];
inline void add(int& x, int y){x += y;if (x >= mod) x -= mod;}
node dp(ll x){
if (!x) return {0, 0, 1};
if (x == 1) return {m, m, n};x ++;
if (popc(x) == 1 && F[log2(x)].vp) return F[log2(x)];x --;
ll ri = 1ll << 64 - __builtin_clzll(x);
ri --;ll lx = ri - 1 >> 1, rx = lx;
ll ls = ri - x, di = ri + 1 >> 2;
if (ls <= di) rx -= ls;
else rx -= di, lx -= ls - di;
node lf = dp(lx), rf = dp(rx), res;
res.vp = 1ll * lf.vp * rf.vp % mod;
res.vg = (1ll * lf.vg * rf.vp + 1ll * rf.vg * lf.vp + res.vp) % mod * m % mod;
res.vp = 1ll * res.vp * n % mod;
res.vf = (1ll * lf.vg * rf.vg % mod * m + res.vg +
(1ll * lf.vf * rf.vp + 1ll * rf.vf * lf.vp) % mod * n) % mod;
x ++;if (popc(x) == 1) F[log2(x)] = res;x --;return res;
}
int main(){scanf("%d", &t);
while (t --){scanf("%lld%d", &k, &n), sum = 0;
for (m = 1;m <= n;m ++)
memset(F, 0, sizeof(F)), ans[m] = dp(k).vf;
for (m = n;m >= 1;m --)
add(ans[m], mod - ans[m - 1]),
add(sum, 1ll * ans[m] * m % mod);
printf("%d\n", sum);
}
}
```