CF1615E 题解

· · 题解

赛时在 C 挣扎太久了导致这道题没写……

purple crayon 不是出题人吗(

题读错还在瞎编的人是屑/kk

红蓝双方在以 1 为根的树上涂色,红先手。每次一方可以涂一个子树为自己颜色,并可以涂任意次。但红方最多只能涂 k 个节点,且每次任意一方要涂的子树内不能有对方颜色。设最后有 w 个白色节点,r 个红色节点,b 个蓝色节点,则分数为 w(r-b)

红方想最大化分数,蓝方想最小化分数,假设两方都绝顶聪明,求最后的分数。

结论 1:红方涂一个节点 x,相当于阻止蓝方在 x 到节点 1 的路径上涂色。

只有这条路径上的节点的子树包含了 x

结论 2(关键结论):红方涂节点越多,能控制的节点越多,蓝方能涂的节点越少,分数不会变小。

注意到 w+r+b=n,令 f(x)=x(n-x),则分数可以化简为 w(r-b)=(n-r-b)(r-b)=r(n-r)-b(n-b)=f(r)-f(b)。注意到对于 f(x),在 x=\left\lfloor n\over 2\right\rfloor 时取到最大。

R 为红色能控制的节点数,则 b\le n - R,则 R 越大 \max f(b) 越小。

\begin{aligned} g(r+1)&=f(r+1)-f(R(r+1))\\ &\ge f(r+1)-f(R(r)+1)\\ &=f(r)-f(R(r))+(-2r+n+1-(-2R(r)+n+1))\\ &\ge f(r)-f(R(r))=g(r) \end{aligned}

综上,分数随 r 递增。

结论 3:红方会优先涂叶子节点。

根据结论 2,如果我们将涂色的节点换成其中一个儿子,那么能控制的节点数不会减少,类推可得。

所以,红方只需要每次取叶子节点,使增加的控制的节点数尽可能多即可。注意到这一过程类似长链剖分,所以我们只需要跑长链,然后根据链长倒序取即可。

有一个细节:如果叶子节点数 \le k,即红色能覆盖所有节点时,可以不必覆盖所有节点,而让 r 尽可能接近 n\over 2 是最优的。

代码:

#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;
const int max_n = 200000;

int hd[max_n], des[max_n<<1], nxt[max_n<<1], dep[max_n], dd[max_n], ln[max_n], e_cnt = 0, lac = 0;

void add(int s, int t)
{
    des[e_cnt] = t;
    nxt[e_cnt] = hd[s];
    hd[s] = e_cnt++;
}

inline void chmax(int& a, int b) { if (a < b) a = b; }
void dfs(int id, int fa)
{
    for (int p = hd[id]; p != -1; p = nxt[p])
        if (des[p] != fa)
        {
            dd[des[p]] = dd[id] + 1;
            dfs(des[p], id);
            chmax(dep[id], dep[des[p]]);
        }
    dep[id]++;
}

void dfs2(int id, int top)
{
    if (nxt[hd[id]] == -1 && id)
    {
        ln[lac++] = dd[id] - dd[top] + 1;
        return;
    }

    int mx = -1, mxp;
    for (int p = hd[id]; p != -1; p = nxt[p])
        if (dep[des[p]] < dep[id] && dep[des[p]] > mx)
            mx = dep[des[p]], mxp = des[p];

    dfs2(mxp, top);
    for (int p = hd[id]; p != -1; p = nxt[p])
        if (dep[des[p]] < dep[id] && des[p] != mxp)
            dfs2(des[p], des[p]);
}

signed main()
{
    memset(hd, -1, sizeof hd);

    ios_base::sync_with_stdio(false);
    cin.tie(0);

    int n, k;

    cin >> n >> k;
    for (int i = 1, x, y; i < n; i++)
    {
        cin >> x >> y;
        add(x-1, y-1), add(y-1, x-1);
    }
    dfs(0, -1), dfs2(0, 0);

    if (k >= lac)
    {
        if (k > n / 2 && lac <= n / 2) k = n / 2;
        else if (lac > n / 2) k = lac;
        cout << 1ll * k * (n - k) << endl;
    }
    else
    {
        sort(ln, ln + lac, [](int a, int b) { return a > b; });

        int pn = n;
        for (int i = 0; i < k; i++)
            pn -= ln[i];

        if (pn > n / 2) pn = n / 2;
        cout << 1ll * k * (n - k) - 1ll * pn * (n - pn) << endl;
    }

    return 0;
}