题解:P9962 [THUPC 2024 初赛] 一棵树

· · 题解

来点魔怔做法。

首先根据官方题解的思路,考虑树上背包,f_{i,j} 代表 i 的子树内,选了 j 个黑点的最小权值,转移就相当于先把所有儿子的背包合并起来然后再给每一位加上 |k-2j|

发现这个东西具有下凸性,可以通过归纳法证明,那么我们就可以用带懒标记的可合并的堆维护凸包的斜率,具体的操作为,对每个结点先往堆里加入一个 0,然后把所有儿子的堆也合并进来,然后前 \lfloor\frac{k}{2}\rfloor 小的斜率减 2,如果 k 为奇数,那么第 \frac{k+1}{2} 斜率不变,剩下的斜率全部加 2

写左偏树维护前 \lfloor\frac{k}{2}\rfloor 个数应该就做完了,但是通过观察这个操作的性质,我们可以找到更优秀的做法。我们发现最后根节点的堆里面的每个数都是由在某个节点加入的 0 转变而来,那我们就可以反过来考虑各个节点加入的 0 都会经历怎样的操作。

发现一个点被弹出堆的位置就是最近的子树中有 \lfloor\frac{k}{2}\rfloor 个节点深度比他大的祖先,于是我们可以按深度从小到大将节点排序,然后依次删除,同时找到第一个子树中还剩 \lfloor\frac{k}{2}\rfloor 以上个节点未被删的祖先,用 dfn 序和树状数组维护子树内未被删节点个数,再用并查集维护最近满足条件的节点即可。如果 k 为偶数,那么就可以直接求出这个 0 到根节点时的值。若 k 为奇数,那么再 dfs 一遍,每个点上维护已经被弹出的节点中的最小值即可。最后到根节点再用 k\times(n-1) 减去最小的 k 个斜率即可。

细节可以看代码。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e5+5;
int n,k;
vector<int>vec[N];
int dep[N],f[N],fa[N],id[N],ed[N],dfn[N],dfcnt;
vector<int>e[N];
void dfs(int u)
{
    dep[u] = dep[f[u]] + 1;
    dfn[u] = ++dfcnt;
    for(auto v:e[u]) if(v != f[u])
    {
        f[v] = u;
        dfs(v);
    }
    ed[u] = dfcnt;
}
bool cmp(int i,int j){return dep[i] < dep[j];}
int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
int tr[N];
int lowbit(int c){return c & (-c);}
void upd(int c,int x){for(;c <= n;c += lowbit(c)) tr[c] += x;}
int qry(int c){int Res = 0;for(;c;c -= lowbit(c)) Res += tr[c];return Res;}
vector<int>res;
int mn[N];
void dfs1(int u)
{
    mn[u] = 1;
    for(auto x:vec[u])
    {
        if(x < mn[u])
        {
            if(mn[u] != 1) res.push_back(mn[u]+2*(dep[u]-1));
            mn[u] = x;
        }
        else
        {
            res.push_back(x+2*(dep[u]-1));
        }
    }
    for(auto v:e[u]) if(v != f[u])
    {
        dfs1(v);
        if(mn[v] == 1) continue;
        int x = mn[v];
        if(x < mn[u])
        {
            if(mn[u] != 1) res.push_back(mn[u]+2*(dep[u]-1));
            mn[u] = x;
        }
        else
        {
            res.push_back(x+2*(dep[u]-1));
        }
    }
}
int main()
{
    ios::sync_with_stdio(false);cin.tie(0);
    cin>>n>>k;
    for(int i = 1;i <= n;i++) id[i] = i;
    for(int i = 1,u,v;i < n;i++)
    {
        cin>>u>>v;
        e[u].push_back(v),e[v].push_back(u);
    }
    dfs(1);
    sort(id+1,id+1+n,cmp);
    for(int i = 1;i <= n;i++) upd(i,1),fa[i] = i;
    for(int i = 1;i <= n;i++)
    {
        int u = id[i];
        u = find(u);
        //if(u%10000 == 0) cerr<<"!"<<u<<"\n";
        while(u != 1)
        {
            if(qry(ed[u]) - qry(dfn[u]-1) > k/2) break;
            fa[u] = f[u];
            u = find(u);
        }
        upd(dfn[u],-1);
        vec[u].push_back(-2*(dep[id[i]]-dep[u]));
    }
    if(k%2 == 0)
    {
        for(int i = 1;i <= n;i++) for(auto x:vec[i]) res.push_back(x+2*(dep[i]-1));
        sort(res.begin(),res.end());
        ll ans = 1ll*(n-1)*k;
        for(int i = 0;i < k;i++) ans += res[i];
        cout<<ans<<"\n";
        return 0;
    }
    ll ans = 1ll*(n-1)*k;
    dfs1(1);
    res.push_back(mn[1]);
    sort(res.begin(),res.end());
    for(int i = 0;i < k;i++) ans += res[i];
    cout<<ans<<"\n";
    return 0;
}