题解:P14020 [ICPC 2024 Nanjing R] 二叉树

· · 题解

题意应该不用讲了。

首先你知道这棵树的形态,因为题目要求我们在 \log 次内找出关键点,所以我们每次可以选树的重心。

接下来因为这棵树是二叉树,所以它的每个节点的度数不超过 3,所以我们可以选出它的最大的两个子节点进行查询,如果它们两个中有一个离关键点最近,那么就把它和重心的边割掉,用它的子树继续递归,否则就把这两个子节点都割掉,继续递归。

最后判一下节点数为 1 或 2 的就行了。

ps:为什么是选最大的两个子节点询问?

考虑这个数据:

6
4 3
0 1
0 0
5 6
0 0
0 0

如果你选了 1 为重心,并且查了 2 和 3。如果交互库回答 1。那么就剩下 1、4、5、6 这些节点,然后你又要花 2 次才能查到,总共花了 3 次。但是如果你一开始查了 2 和 4,那么你接下来查一次就可以查到了。

最后你还要注意空间,不知道为什么我的黑子团在这题上 MLE 并且吃了十几发罚时。

代码:

#include<bits/stdc++.h>
using namespace std;
#define N 100010
int T,n,siz[N],heavy,num;
bool vis[N];
vector<int>to[N];
void init(int x,int fa){
    num++,siz[x]=1;
    for(auto y:to[x]){
        if(y==fa){continue;}
        if(vis[y]){continue;}
        init(y,x);
        siz[x]+=siz[y];
    }
}
void dfs(int x,int fa){
    int maxn=0;
    for(auto y:to[x]){
        if(y==fa){continue;}
        if(vis[y]){continue;}
        maxn=max(maxn,siz[y]);
        dfs(y,x);
    }
    maxn=max(maxn,num-siz[x]);
    if(maxn<=num/2){heavy=x;}
}
signed main(){
    cin>>T;
    while(T--){
        cin>>n;
        for(int i=1;i<=n;i++){to[i].clear(),vis[i]=0;}
        for(int i=1;i<=n;i++){
            int l,r;
            cin>>l>>r;
            if(l){to[i].push_back(l),to[l].push_back(i);}
            if(r){to[i].push_back(r),to[r].push_back(i);}
        }
        int rt=1;
        while(1){
            num=0;
            init(rt,0),dfs(rt,0);
            if(num==1){
                cout<<"! "<<rt<<endl;
                break;
            }
            if(num==2){
                int oth=0;
                for(auto y:to[rt]){if(!vis[y]){oth=y;}}
                cout<<"? "<<rt<<' '<<oth<<endl;
                int x;
                cin>>x;
                if(x==0){cout<<"! "<<rt<<endl;}else{cout<<"! "<<oth<<endl;}
                break;
            }
            int u=0,v=0,w=0,szu,szv,szw;
            for(auto x:to[heavy]){if(!vis[x]){if(!u){u=x;}else{if(!v){v=x;}else{w=x;}}}}
            if(siz[u]<siz[heavy]){szu=siz[u];}else{szu=num-siz[heavy];}
            if(siz[v]<siz[heavy]){szv=siz[v];}else{szv=num-siz[heavy];}
            if(w){
                if(siz[w]<siz[heavy]){szw=siz[w];}else{szw=num-siz[heavy];}
                if(szu<=szw){swap(u,w),swap(szu,szw);}
                if(szv<=szw){swap(v,w),swap(szv,szw);}
            }
            cout<<"? "<<u<<' '<<v<<endl;
            int x;
            cin>>x;
            if(x==1){vis[u]=vis[v]=1,rt=heavy;continue;}
            if(x==0){vis[heavy]=1,rt=u;continue;}
            if(x==2){vis[heavy]=1,rt=v;continue;}
        }
    }
    return 0;
}