【题解】CF1499F Diameter Cuts

· · 题解

提供一个比较笨的 DP 方法。

f_{u,i} 表示从 u 的父亲走到 u 再一直向下走,途中只能走未被割的边,所形成的最长路径的边数,称为 u 的深度。(u 到父亲的边被割掉时,不能形成路径,故 i=0。)

我们来处理 f_{u,i+1} 的转移。考虑把子结点排成一列,其中深度最大的结点为 v(有多个则选最右边的),则左边的结点深度 \le i,右边的 <i。由题意得其余结点的深度不超过 m-i。记 p(v) 为子结点 v 排成一列的编号,则转移方程为

f_{u,i+1}&=\sum_{v\in son(u)}\prod_{v'\in son(u),v'\ne v}\sum_{i'=1}^{\min(i-[p(v')<p(v)],m-i)}f_{v,i'},\\ f_{u,0}&=\sum_{i=0}^{m}f_{u,i+1}. \end{aligned}

接下来考虑优化。首先可以对 i 求前缀和,优化最后一个求和号。

s_{u,i}=\sum_{i'=0}^{i}f_{u,i}.

然后把子节点的 s 做前后缀积,优化掉求积号。

pre_{v,i}=\prod_{v'\in son(u),p(v')<v}s_{v',i},\\ suf_{v,i}=\prod_{v'\in son(u),p(v')>v}s_{v',i}. \end{aligned}

最后的转移方程为

f_{u,i+1}=\sum_{v\in son(u)}pre_{v,\min(i,m-i)}\cdot suf_{v,\min(i-1,m-i)}.

需要特判叶子。

1 为根进行 dfs,答案为 f_{1,0}

#include <bits/stdc++.h>
using namespace std;
template <typename T> void rd(T &x) {
    x = 0;
    int f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - 48;
        c = getchar();
    }
    x *= f;
}
template <typename T, typename... T2> void rd(T &x, T2 &...y) {
    rd(x), rd(y...);
}

typedef long long ll;
const int mod = 998244353;

const int N = 5010, M = N;
int n, m;
vector<int> e[N];

int f[N][M], sum[N][M], pre[M];

void dfs(int u, int fa) {
    if (e[u].size() == !!fa) {
        f[u][0] = sum[u][0] = f[u][1] = 1;
        fill(sum[u] + 1, sum[u] + m + 2, 2);
        return;
    }
    for (int v : e[u]) {
        if (v == fa) continue;
        dfs(v, u);
    }
    for (int t = 0; t <= 1; ++t) {
        fill(pre, pre + m + 1, 1);
        bool fl = false;
        for (int v : e[u]) {
            if (v == fa) continue;
            if (t && fl) f[v][0] = 0;
            for (int i = t; i <= m; ++i) {
                f[v][i] = f[v][i] * (ll)pre[min(i - t, m - i)] % mod;
            }
            for (int i = 0; i <= m; ++i)
                pre[i] = pre[i] * (ll)sum[v][i] % mod;
            fl = true;
        }
        reverse(e[u].begin(), e[u].end());
    }
    for (int i = 0; i <= m; ++i) {
        for (int v : e[u]) {
            if (v == fa) continue;
            (f[u][i + 1] += f[v][i]) %= mod;
        }
        (f[u][0] += f[u][i + 1]) %= mod;
    }
    sum[u][0] = f[u][0];
    for (int i = 0; i <= m; ++i) {
        (sum[u][i + 1] = sum[u][i] + f[u][i + 1]) %= mod;
    }
}

int main() {
    rd(n, m);
    for (int i = 1; i < n; ++i) {
        int u, v;
        rd(u, v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1, 0);
    printf("%d", sum[1][0]);
}