题解:P14780 [COCI 2025/2026 #3] 国家 / Drzava

· · 题解

并且任意两座城市之间的最短路经过的道路数都小于 36。

前排提示:题面中这句话保证了给出的是一棵直径不超过 36 的树,但实际上这个题有一个完全不需要这个性质的 \mathcal O(n^2) 做法,如果你因为这个性质误入歧途可以重新思考一下。

考虑先不管首都在哪个点,设次级城市的点集为 $S(|S|\ge 2)$。如果 $\exist x,y,z\in S$,满足 $y$ 在 $x$ 到 $z$ 的路径上,则这个 $S$ 一定不合法。如果 $|S|=2$ 且 $S$ 中两个点相邻,此时也找不到合法的首都。此外,可以说明一定存在一个点 $x$,使得 $x$ 可以作为首都,也就是将 $x$ 设为树根时 $S$ 中任意两点没有祖先后代关系。证明考虑以 $1$ 为根的情况下 $S$ 的形态,有两种情况: 1. 如果此时 $S$ 中任意两点没有祖先后代关系,则 $1$ 就是合法的 $x$; 2. 否则,设点 $y\in S$,$y$ 的一个子节点为 $x$,且 $x$ 的子树中有 $S$ 中的点。此时可以发现,由于 $S$ 中任意两点路径不能经过 $y$,所以 $S$ 中其它的点都在 $x$ 子树中且这些点之间没有祖先后代关系。同时由假设 $x\notin S$,所以这就是一个可以作为首都的 $x$。 此时,任何一个不在 $S$ 中某一点子树中的 $y$ 也可做为首都(等价的表述是 $x$ 到 $y$ 的路径上没有 $S$ 中的点,或者删除 $S$ 中的点后 $x$ 和 $y$ 仍然连通),即可作为首都的点的数量为 $n-\sum_{i\in S}sz_i$。注意这里的 $sz_i$ 指的是以 $x$ 为根时 $i$ 的子树大小,同时将 $x$ 换成任意一个可以作为首都的 $y$ 不会改变 $S$ 中的点的 $sz$。 至此,dp 的状态设计和转移设计都已经有了。在以 $1$ 为根的情况下,设 $g_{x,i}$ 表示 $x$ 的子树中选了 $i$ 个点到 $S$ 中,且它们都没有祖先后代关系,$f_{x,i}$ 表示所有这些方案中 $\sum sz_i$ 的和,在这种情况下 $sz_i=siz_i$,也就是以首都为根时的子树大小等于以 $1$ 为根时的子树大小。转移只需要考虑三种情况: 1. 合并 $x$ 的子树; 2. 将 $x$ 选入 $S$ 中,且 $x$ 的子树中没有其它 $S$ 中的点; 3. 将 $x$ 选入 $S$ 中,且 $x$ 的子树中有其它的点。 其中第一种和第二种都是常规的树形背包。对于第三种,此时 $S$ 中所有点都已经确定了。枚举 $y,i$,其中 $y$ 为 $x$ 的子节点,表示 $S$ 中其它的 $i$ 个点都在 $y$ 子树中。此时点 $y$ 可以作为首都,且对应的 $sz_x=n-siz_y$,所以计算 $G=g_{y,i},F=f_{y,i}+g_{y,i}\times(n-siz_y)$,将 $n\times G-F$ 计入 $ans_{i+2}$ 即可($i+2$ 是加上点 $x$ 和首都)。 最后考虑 $S$ 中任意两点没有祖先后代关系的情况,此时 $1$ 可以作为首都,故对于每个 $i$,将 $n\times g_{1,i}-f_{1,i}$ 计入 $ans_{i+1}$ 即可。 前两种转移由树形背包结论复杂度为 $\mathcal O(n^2)$,第三种复杂度为 $(y,i)$ 的数量也是 $\mathcal O(n^2)$,总复杂度 $\mathcal O(n^2)$。 ```cpp {.line-numbers} #include <bits/stdc++.h> typedef long long LL; typedef __int128 LLL; typedef unsigned long long ULL; typedef std::pair<int, int> pii; typedef long double RN; #define fi first #define se second #define MP std::make_pair #define EB emplace_back LL read() { LL s = 0; int f = 1, c = getchar(); for (; !isdigit(c); c = getchar()) f ^= (c == '-'); for (; isdigit(c); c = getchar()) s = s * 10 + (c ^ 48); return f ? s : -s; } template<typename T> void write(T x, char end = '\n') { if (x < 0) x = -x, putchar('-'); static int d[100], cur = 0; do { d[++cur] = x % 10; } while (x /= 10); while (cur) putchar(48 ^ d[cur--]); putchar(end); } const int inf = 0x3f3f3f3f; const LL INF = 0x3f3f3f3f3f3f3f3fll; template<typename T> void Fmin(T &x, T y){ if (y < x) x = y; } template<typename T> void Fmax(T &x, T y){ if (x < y) x = y; } const int MOD = 1e9 + 7; int fplus(int x, int y){ if ((x += y) >= MOD) return x - MOD; return x; } int fminus(int x, int y){ if ((x -= y) < 0) return x + MOD; return x; } void Fplus(int &x, int y){ if ((x += y) >= MOD) x -= MOD; } void Fminus(int &x, int y){ if ((x -= y) < 0) x += MOD; } int fpow(int x, int y = MOD - 2) { int res = 1; for (; y; y >>= 1, x = (LL)x * x % MOD) if (y & 1) res = (LL)res * x % MOD; return res; } const int MAXN = 3005; int n, siz[MAXN]; int f[MAXN][MAXN], g[MAXN][MAXN]; int ans[MAXN]; std::vector<int> e[MAXN]; void dfs(int x, int fat) { for (int y : e[x]) if (y != fat) dfs(y, x); f[x][0] = 0, g[x][0] = 1; static int tf[MAXN], tg[MAXN], sf[MAXN], sg[MAXN]; memset(sf, 0, n << 2); memset(sg, 0, n << 2); for (int y : e[x]) if (y != fat) { memcpy(tf, f[x], (siz[x] + 1) << 2), memset(f[x], 0, (siz[x] + 1) << 2); memcpy(tg, g[x], (siz[x] + 1) << 2), memset(g[x], 0, (siz[x] + 1) << 2); for (int i = 0; i <= siz[x]; i++) for (int j = 0; j <= siz[y]; j++) { Fplus(g[x][i + j], (LL)tg[i] * g[y][j] % MOD); Fplus(f[x][i + j], ((LL)tg[i] * f[y][j] + (LL)tf[i] * g[y][j]) % MOD); } siz[x] += siz[y]; for (int i = 1; i <= siz[y]; i++) { int G = g[y][i], F = (f[y][i] + (LL)G * (n - siz[y])) % MOD; int res = fminus((LL)n * G % MOD, F); Fplus(ans[i + 2], res); } } Fplus(ans[2], siz[x]); ++siz[x], Fplus(f[x][1], siz[x]), Fplus(g[x][1], 1); // printf("---------- %d ------------\n", x); // for (int i = 0; i <= siz[x]; i++) printf("%d%c", f[x][i], " \n"[i == siz[x]]); // for (int i = 0; i <= siz[x]; i++) printf("%d%c", g[x][i], " \n"[i == siz[x]]); // printf("--------------------------\n"); } int main() { n = read(); for (int i = 1; i < n; i++) { int u = read(), v = read(); e[u].push_back(v), e[v].push_back(u); } dfs(1, 0); for (int i = 0; i < n; i++) Fplus(ans[i + 1], fminus((LL)n * g[1][i] % MOD, f[1][i])); for (int i = 1; i <= n; i++) write(ans[i], " \n"[i == n]); return 0; } ```