CF1868C 题解

· · 题解

很好写的 O(\sum m \log n) 做法,目前 solution size 排第三。

因为 m 比较小,所以考虑枚举 \max。容斥一下,设 F_x 表示对于每条链,链上的点 \max \le x,其他点任意填的方案数,即 \max 恰好等于 i 的方案数即为 ans_x = F_x - F_{x - 1}。答案为 \sum_{x = 1}^m x(F_x - F_{x - 1}) = F_m - \sum_{x = 1}^{m - 1} F_x

考虑树形 DP,设 $f_k$ 表示大小为 $k$ 的树内的路径填法和,$g_k$ 表示到大小为 $k$ 树的内**到根**的路径填法和,有转移方程: $$g_k = x(g_{ls}m^{rs} + m^{ls}g_{rs} + m^{k - 1})$$ $$f_k = g_k + xg_{ls}g_{rs} + f_{ls}m^{rs + 1} + m^{ls + 1}f_{rs}$$ 其中 $ls, rs$ 分别为左右子树大小。状态数看似是 $O(n)$ 的,但是实际上左右两边至少一边是 $2^y - 1$ 的形式,所以对 $f_n$ 有用的状态数实际上是 $O(\log n)$ 的。 转移时如果用快速幂求出系数,就会多出一个 $\log$,但是我们对用到的状态都预处理出 $m^k$ 就可以 $O(\log n)$ 求出 $f_n$。 我们对于每一个 $x$,都 $O(\log n)$ 求出 $F_x$(即当前 $x$ 求出的 $f_n$),就可以 $O(m \log n)$ 求出原本的答案了。 实现时不需要开一个 `unordered_map` 去记忆化搜索,只需要记录所有 $k = 2^y - 1$ 的状态,时间复杂度就是对的。 代码真的很好写: ```cpp #include <unordered_map> #include <algorithm> #include <cstring> #include <cstdio> #define mod 998244353 #define popc __builtin_popcountll #define log2 __builtin_ctzll using namespace std; const int N = 100001; typedef long long ll; int t, n, m, ans[N], sum;ll k; struct node{int vf, vg, vp;}F[64]; inline void add(int& x, int y){x += y;if (x >= mod) x -= mod;} node dp(ll x){ if (!x) return {0, 0, 1}; if (x == 1) return {m, m, n};x ++; if (popc(x) == 1 && F[log2(x)].vp) return F[log2(x)];x --; ll ri = 1ll << 64 - __builtin_clzll(x); ri --;ll lx = ri - 1 >> 1, rx = lx; ll ls = ri - x, di = ri + 1 >> 2; if (ls <= di) rx -= ls; else rx -= di, lx -= ls - di; node lf = dp(lx), rf = dp(rx), res; res.vp = 1ll * lf.vp * rf.vp % mod; res.vg = (1ll * lf.vg * rf.vp + 1ll * rf.vg * lf.vp + res.vp) % mod * m % mod; res.vp = 1ll * res.vp * n % mod; res.vf = (1ll * lf.vg * rf.vg % mod * m + res.vg + (1ll * lf.vf * rf.vp + 1ll * rf.vf * lf.vp) % mod * n) % mod; x ++;if (popc(x) == 1) F[log2(x)] = res;x --;return res; } int main(){scanf("%d", &t); while (t --){scanf("%lld%d", &k, &n), sum = 0; for (m = 1;m <= n;m ++) memset(F, 0, sizeof(F)), ans[m] = dp(k).vf; for (m = n;m >= 1;m --) add(ans[m], mod - ans[m - 1]), add(sum, 1ll * ans[m] * m % mod); printf("%d\n", sum); } } ```