题解:P12504 「ROI 2025 Day1」树上的青蛙

· · 题解

P12504 「ROI 2025 Day1」树上的青蛙

简化题意

给定一棵 n 个节点的树,dis_{u,v} 表示树上两点的距离。

再给定一个整数 D

初始时有一个空的二元组集合 S

每次你可以选出两个节点 u,v,满足如下条件:

如果满足上述条件,则可以将 (u,v) 放入 S

最大化 |S|,并输出此时 S 内的元素(如有多个 S 满足要求,考虑任意一种均可)。

题解

知识点:dsu on tree,贪心。

质量很高的题目,断断续续大战了一天。

n\le 14

直接状压/暴搜即可,状压更好写。

思考 1

## $D=1

此时相邻的才能选。

考虑树形 dp,每个节点选或者不选,从儿子转移,记录一下最优情况选了谁就行了。

D\le 200

将节点按深度奇偶性分组,给距离为奇数且 \le D 的点对在新图连边,显然是一个二分图,直接跑二分图最大匹配。

思考 2

维护每个节点维护维护两个 set 表示子树内的节点奇/偶数深度集合,从叶子开始,往上在祖先处启发式合并。

思考 3

如果 D\bmod 2=0,那么 D\leftarrow D-1,方便处理。

可以从下往上贪心,两个 set 在节点 u 启发式合并时,所模拟的路径是从其中一个 set 里的节点走到 u,再走到另一个 set 里的节点,如果路程为 D 或者当前在根节点,那就贪心计入答案。

考虑这么贪心为什么是对的,一对点匹配的贡献都是相同的 1 是前提,而贪心是有顺序的从下往上,显然越晚(路程越靠近 D)配对越好,这说明 u 以下的节点不能让他们这样走的路程为 D,即使往上走还可能出现距离刚好为 D 的点,实际上还不如偏安一隅,往上走浪费的不仅是自己的机会,还有别的点的机会,下面也没有能和你配的了,还能怎么办不必多说。

复杂度可以做到 O(nD\log n+n \log^2 n),实际情况很难跑满,实测可以获得 64 分,如果是赛时我想到这里应该就润了

#include<bits/stdc++.h>
using namespace std;

#define rep(i,l,r) for(int i=(l);i<=(r);++i)
#define per(i,l,r) for(int i=(r);i>=(l);--i)
#define pr pair<int,int>
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()
#define bg(x) (x).begin()
#define ed(x) (x).end()

#define N 502507
// #define int long long

int n,D,rt,d[N];
vector<int>e[N];
vector<pr>ans;
set<pr>s[N][2];

inline void mg(set<pr>&a,set<pr>&b,int du){
    if(sz(a)<sz(b)){
        swap(a,b);
    }

    for(pr u:b){
        if(-u.fi-du>D){
            continue;
        }
        a.insert(u);
    }
    b.clear();
}

inline void calc(set<pr>&a,set<pr>&b,int du,bool rt){
    vector<pr>del;

    if(sz(a)<sz(b)){
        for(pr u:a){
            auto it=b.lower_bound({d[u.se]-2*du-D,0});

            if(-u.fi-du>D){
                del.pb(u);
                continue;
            }

            if(it==b.end()){
                continue;
            }

            int dv=-it->fi;

            if(d[u.se]+dv-2*du!=D&&!rt){
                continue;
            }

            ans.pb({u.se,it->se});
            del.pb(u);b.erase(it);
        }

        for(pr u:del){
            a.erase(u);
        }
    }
    else{
        for(pr u:b){
            auto it=a.lower_bound({d[u.se]-2*du-D,0});

            if(-u.fi-du>D){
                del.pb(u);
                continue;
            }

            if(it==a.end()){
                continue;
            }

            int dv=-it->fi;

            if(d[u.se]+dv-2*du!=D&&!rt){
                continue;
            }

            ans.pb({u.se,it->se});
            del.pb(u);a.erase(it);
        }

        for(pr u:del){
            b.erase(u);
        }
    }
}

inline void dfs(int k,int fa){
    d[k]=d[fa]+1;
    s[k][d[k]&1].insert({-d[k],k});

    for(int x:e[k]){
        if(x==fa){
            continue;
        }

        dfs(x,k);
    }

    for(int x:e[k]){
        if(x==fa){
            continue;
        }

        mg(s[k][0],s[x][0],d[k]);
        mg(s[k][1],s[x][1],d[k]);
    }

    calc(s[k][0],s[k][1],d[k],k==rt);
}

signed main(){
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    cin>>n>>D;

    if(D%2==0){
        D--;
    }

    rep(i,1,n-1){
        int u,v;
        cin>>u>>v;
        e[u].pb(v);
        e[v].pb(u);
    }

    mt19937 rd(114514);
    rt=(rd())%n+1;

    dfs(rt,0);

    // cout<<"root :"<<rt<<"\n";

    cout<<sz(ans)<<"\n";
    for(pr u:ans){
        cout<<u.fi<<' '<<u.se<<"\n";
    }

    return 0;
}

思考 4

这种合并的方式非常精巧,考虑沿用这种方式。

考虑去掉复杂度里的 D

思考上面的算法慢在了哪里,跑满 O(nD\log n) 的情况是带着一堆根本匹配不上的节点一直往上跑,拖慢了速度,如果能每次都可以精准匹配就好了。

思考 5

set 换成 mapvector,可以储存当前每种深度有哪些节点,且能精确访问。

当合并入一个节点的时候,记其深度为 d_1,当前所在的节点的深度为 d_u,则在 map 找到最大的满足 d_1+d_2-2d_u\le D 且奇偶性与 d_1 相反的 d_2,计算出他们可以完美合并(路程刚好为 D)的时候的深度 md=\frac{d_1+d_2-D}{2},然后等到深度为 md 的时候再看能不能计入答案。

思考 6

考虑开一个优先队列维护这个过程,以 md 为第一关键字从大到小处理。

当前节点为 u 时,如果当前满足 md=d_u,则进行处理,看当前的 map 是否有 d_1d_2,有就说明能配对,分别取出 d_1d_2 代表的 vector 的一项,计入答案,然后删除。

这时候判断 vector 是否为空,如果空了就从 map 里删掉,表示当前该不存在该深度的节点。

然后考虑拓展,分别找到 map 中最大的满足奇偶性相同,且小于等于 d_1d_2 的深度去拓展,因为此时 map 可能不存在 d_1 或者 d_2 的深度了。

如果 md<d_u,直接退出,留给深度更小的节点去处理。

特殊地,如果 u 是根节点,则留给深度更小的节点去处理是没有意义的,可以直接全处理了。

思考 7

每次合并都要加入所有子节点 map 数组的所有元素,这样跑下来,虽然能精确地匹配,但是太慢了,时间复杂度是 O(n^2 \log n)

考虑 dsu on tree,对于每个节点的,直接继承重儿子的信息,然后再一一和轻儿子合并,这样复杂度就变成了 O(n\log^2 n)

你问为什么 O(n \log^2 n)5\times 10^5 只跑了 1 秒左右?因为 dsu on tree 根本跑不满,特别是随机了一个根的情况下。

代码细节比较多,有很多边界要考虑,比较考验调试能力。

#include<bits/stdc++.h>
using namespace std;

#define rep(i,l,r) for(int i=(l);i<=(r);++i)
#define per(i,l,r) for(int i=(r);i>=(l);--i)
#define pr pair<int,int>
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(),(x).end()
#define sz(x) (int)(x).size()
#define bg(x) (x).begin()
#define ed(x) (x).end()

#define N 502507
// #define int long long

int n,D,d[N],siz[N],son[N],rt;
vector<int>e[N];
vector<pr>ans;

inline void dfs(int k,int fa){
    d[k]=d[fa]+1;
    siz[k]=1;

    for(int x:e[k]){
        if(x==fa){
            continue;
        }

        dfs(x,k);
        siz[k]+=siz[x];

        if(siz[x]>siz[son[k]]){
            son[k]=x;
        }
    }
}

struct myds{
    int du;

    map<int,vector<int>>s[2];
    priority_queue<pr>q;

    inline void clr(){
        s[0].clear();
        s[1].clear();
        while(sz(q)){
            q.pop();
        }
    }

    inline void cmp(int d1){
        int ip=!(d1&1);

        auto it=s[ip].upper_bound(2*du+D-d1);

        if(it==s[ip].begin()||!s[d1&1].count(d1)){
            return;
        }

        it--;
        int d2=it->fi,md=(d1+d2-D)/2;

        q.push({md,d1});
    }

    inline void add(int k){
        int pos=d[k]&1;

        if(s[pos].count(d[k])){
            s[pos][d[k]].pb(k);
        }
        else{
            s[pos][d[k]].pb(k);
            cmp(d[k]);
        }
    }

    inline void run(bool rt){
        while(sz(q)){
            pr u=q.top();

            if(du>u.fi&&!rt){
                break;
            }

            q.pop();

            int d1=u.se,d2=D+u.fi*2-d1;
            bool i1=d1&1,i2=d2&1;

            if(!s[i1].count(d1)){
                continue;
            }
            if(!s[i2].count(d2)){
                continue;
            }

            int x=s[i1][d1].back();
            s[i1][d1].pop_back();

            int y=s[i2][d2].back();
            s[i2][d2].pop_back();

            ans.pb({x,y});

            // cout<<d1<<' '<<d2<<" cmped "<<x<<' '<<y<<" at dep"<<u.fi<<"\n";

            if(s[i1][d1].empty()){
                s[i1].erase(d1);
            }

            auto it=s[i1].upper_bound(d1);
            if(it!=s[i1].begin()){
                it--;
                cmp(it->fi);
            }

            if(s[i2][d2].empty()){
                s[i2].erase(d2);
            }

            it=s[i2].upper_bound(d2);
            if(it!=s[i2].begin()){
                it--;
                cmp(it->fi);
            }
        }
    }
}a[N];

inline void sol(int k,int fa){
    // cout<<k<<" st\n";

    if(!son[k]){
        a[k].du=d[k];
        a[k].add(k);
        return;
    }

    sol(son[k],k);
    swap(a[k],a[son[k]]);
    a[k].du=d[k];

    a[k].add(k);

    // cout<<k<<" hson\n";

    while(sz(a[k].s[0])){
        auto it=a[k].s[0].end();
        it--;

        if(it->fi>a[k].du+D){
            a[k].s[0].erase(it);
        }
        else{
            break;
        }
    }
    while(sz(a[k].s[1])){
        auto it=a[k].s[1].end();
        it--;

        if(it->fi>a[k].du+D){
            a[k].s[1].erase(it);
        }
        else{
            break;
        }
    }

    // cout<<k<<" pop\n";

    for(int x:e[k]){
        if(x==fa||x==son[k]){
            continue;
        }

        sol(x,k);

        // cout<<k<<" mg "<<x<<"\n";

        for(auto v:a[x].s[0]){
            for(int x:v.se){
                // cout<<k<<" add "<<x<<"\n";
                a[k].add(x);
            }
        }

        for(auto v:a[x].s[1]){
            for(int x:v.se){
                // cout<<k<<" add "<<x<<"\n";
                a[k].add(x);
            }
        }

        a[x].clr();
    }

    a[k].run(k==rt);

    // cout<<k<<" run\n";
}

signed main(){
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    cin>>n>>D;

    if(D%2==0){
        D--;
    }

    rep(i,1,n-1){
        int u,v;
        cin>>u>>v;
        e[u].pb(v);
        e[v].pb(u);
    }

    mt19937 rd(114514);
    rt=(rd())%n+1;

    // cout<<"root"<<rt<<"\n";

    dfs(rt,0);
    sol(rt,0);

    cout<<sz(ans)<<"\n";
    for(pr u:ans){
        cout<<u.fi<<' '<<u.se<<"\n";
    }

    return 0;
}