题解:AT_abc394_f [ABC394F] Alkane

· · 题解

定义

对每个非根节点 vdp_v 表示以 v 为根的子树中,添加 v 后形成烷的最大顶点数。

转移

v 度数大于等于 1,则 dp_v = 1(仅包含 v 和其父节点)。

v 度数大于等于 4,需要 3 个子节点:选 3dp_u 最大的子节点 u_1, u_2, u_3,用 dp_{u_1} + dp_{u_2} + dp_{u_3} + 1 更新。

枚举每个顶点 v 作为最深顶点:若 v 度数为 1,用 dp_u + 1uv 的父节点)更新。

v 度数大于等于 4,用 dp_{u_1} + dp_{u_2} + dp_{u_3} + dp_{u_4} + 1u_1, u_2, u_3, u_4v 的子节点)更新。

取所有情况的最大值。

#include <bits/stdc++.h>
// #include <atcoder/all>
// #define int long long
using namespace std;
inline int read()
{
    int f = 0, ans = 0;
    char c = getchar();
    while (!isdigit(c))
        f |= c == '-', c = getchar();
    while (isdigit(c))
        ans = (ans << 3) + (ans << 1) + c - 48, c = getchar();
    return f ? -ans : ans;
}
void write(int x)
{
    if (x < 0)
        putchar('-'), x = -x;
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
constexpr int N = 2e5 + 5;
int n, ans, f[N];
vector<int> g[N];
vector<pair<int, int>> sub_f[N];
void dfs(int u, int fa)
{
    f[u] = 1;
    for (int &v : g[u])
        if (v != fa)
        {
            dfs(v, u);
            sub_f[u].emplace_back(v, f[v]);
        }
    sort(begin(sub_f[u]), end(sub_f[u]), [](auto &x, auto &y)
         { return x.second > y.second; });
    if (size(sub_f[u]) >= 3)
    {
        f[u] += f[sub_f[u][0].first] + f[sub_f[u][1].first] + f[sub_f[u][2].first];
        if (fa)
            ans = max(ans, f[u] + 1);
    }
    if (size(sub_f[u]) >= 4)
    {
        ans = max(ans, 1 + f[sub_f[u][0].first] + f[sub_f[u][1].first] + f[sub_f[u][2].first] + f[sub_f[u][3].first]);
    }
}
void resolve(int u, int fa, int from_fa)
{
    if (fa && f[u] != 1)
        ans = max(ans, f[u] + from_fa);
    if (fa)
        sub_f[u].emplace_back(0, from_fa);
    sort(begin(sub_f[u]), end(sub_f[u]), [](auto &x, auto &y)
         { return x.second > y.second; });
    for (int &v : g[u])
        if (v != fa)
        {
            int sum = 0, cnt = 0;
            for (auto &[i, w] : sub_f[u])
                if (i != v)
                {
                    sum += max(f[i], w);
                    if (++cnt == 3)
                        break;
                }
            if (cnt == 3)
                resolve(v, u, sum + 1);
        }
}
signed main()
{
    // freopen("in.in", "r", stdin);
    // freopen(".out", "w", stdout);
    cin >> n;
    for (int i = 1; i < n; ++i)
    {
        int u, v;
        cin >> u >> v;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    ans = -1;
    dfs(1, 0);
    resolve(1, 0, 0);
    write(ans);
    return 0;
}