题解 CF1761G 【Centroid Guess】

· · 题解

这里是官方题解,CF 官方英文题解由此题解翻译得到。

假设我们现在确定了树的重心在 uv 的链上,考虑如何找到它。

uv 的链经过的所有点依次为 c_1,c_2,\dots,c_k,我们把由链上一点 c_i 出发不经过链上其他点可到达的点的集合称为 A_i,显然 A_1,A_2,\dots,A_kk 个集合不重不漏地包含了树上所有的 n 个点。

设集合 A_i 的大小为 s_i,链上显然存在恰好一个点 c_x 满足 \max\{\sum\limits_{i=1}^{x-1}s_i,\sum\limits_{i=x+1}^ks_i\}\le\lfloor\frac{n}{2}\rfloor,注意到不满足该条件的链上的点一定不是重心,而目前已经确定重心在链上,所以此时 c_x 就是树的重心。

接下来考虑通过 2n 次询问得到树上每个点所属集合,进而求出 s。对于一个点 x,询问 dis_{u,x}dis_{v,x},容易发现对于同个集合中的点,dis_{u,x}-dis_{v,x} 的值全相同,若 dis_{u,x}-dis_{v,x}=t_1dis_{u,v}=t_2,则 x\in A_{(t_1+t_2)/2+1}。于是我们求出了树上每个点所属集合,也求出了 s,找到了重心。这最多需要消耗 7.5\cdot10^4\times2=1.5\cdot10^5 次询问。

现在问题是找到点对 u,v,使得树的重心在 uv 的链上。

我们取一个常数 M,使得 \frac{M(M-1)}{2}\le5\cdot10^4,然后在树上随机选取 M 个点,并询问两两间距离,这需要消耗 \frac{M(M-1)}{2}\le5\cdot10^4 次询问。

设这 M 个点分别为 p_1,p_2,\dots,p_M,我们建出包含 LCA 的以 p_1 为根的虚树。容易发现 dis_{x,y}+dis_{y,z}=dis_{x,z},当且仅当 yxz 的链上。对于一个点 p_x,我们可以找到 p_1p_x 的链上的所有点,找到这些点中除自己外与 p_x 最近的点与 p_x 连边,它就是点集中 p_x 的深度最大的祖先。

此时我们求出了不包含 LCA 的以 p_1 为根的虚树,接下来我们将添加 LCA 至虚树中。设包含 LCA 的虚树为 T

我们从 p_1 开始 dfs,设当前节点为 u,显然 u 的儿子之间不构成祖先后代关系,我们枚举 u 的儿子 v,若另一个儿子 x 满足 u 不在原树中 xv 的链上,则在 Tvxu 的同一个子树内,找到所有在 T 中与 vu 的同一个子树内的点后,容易计算它们 LCA 的深度,并处理出 LCA 到当前虚树所有点的距离,接着断开原来的边,从 LCA 向 T 中与 vu 的同一个子树内的点连边,然后从 u 向 LCA 连边,然后对 u 重复上述过程,直到 u 的儿子中任意两个点都不在 Tu 的同一个子树内,然后向目前 u 的儿子 dfs。不难发现,完成 dfs 后,我们就得到了完整的虚树。

仅仅对于原始随机到的 M 个点,我们认为它的点权是 1,虚树中其他点点权均为 0,我们找到虚树的带权重心(有多个时可任意取其中一个),并以它作为虚树的根,选取它最大的两个子树进行 dfs,每次都走最大的子树,这样得到两个叶子,称其为 u,v,原树的重心有极高的几率在 uv 的链上。然后套用之前的算法即可求得重心。两部分总询问次数相加不超过 2\cdot10^5

正确性证明:

若原树的重心不在 uv 的链上,那么设虚树的重心在原树重心的子树 E 中,若除去 E 外,原树的重心的其他子树包含了至少 1/3 的我们随机出来的节点,那么原树重心必在 uv 的链上,也就是说 E 中至少包含了 2/3 的我们随机出来的节点,又由于 E 的大小不超过 \lfloor\frac{n}{2}\rfloor,所以随机一个点不在 E 中的概率至少为 \frac{1}{2},随机 M 个点,至多有 1/3 的不在 E 中,E 才能包含 2/3 的我们随机出来的节点,也就是,随机 M 个点,每个点超过 1/2 的概率不合法,至多只能有 1/3 的点不合法。只有这样才会使得该算法出错。这个概率不高于 \sum\limits_{i=0}^{M/3}C_M^i/2^M,取 M=316,约为 6\cdot10^{-10}

fun fact:

以下是 std(Main Correct Solution):

#include<bits/stdc++.h>
using namespace std;
#define N 75005
#define S 316
#define M 640
int read(){
    int w=0;
    char c=getchar();
    while(c<'0'||c>'9')c=getchar();
    while(c>='0'&&c<='9')w=w*10+c-48,c=getchar();
    return w;
}
void report(int x){
    printf("! %d\n",x);
    exit(0);
}
const int lim=2e5;
int n,m,p[N],dis[M][M],fa[M],dep[M],d1[N],d2[N],siz[M],maxsiz[M],r[N],w[N],du,dv,idx;
bool vis[M];
vector<int>e[M],g[M];
mt19937 rnd(time(NULL)^(unsigned long long)(new char));
int rad(int x,int y){
    return rnd()%(y-x+1)+x;
}
bool in(int x,int y,int z){
    return dis[x][y]+dis[y][z]==dis[x][z];
}
void get(int u){
    vis[u]=1;
    for(auto v:e[u])
        get(v);
}
bool work(int u){
    for(auto v1:e[u]){
        int l=0;
        for(auto v2:e[u])
            if(!in(v1,u,v2))l=max(l,(dis[v1][u]+dis[v1][v2]-dis[u][v2])/2);
        if(!l)continue;
        l=dis[u][v1]-l;
        for(auto v2:e[u])
            if(!in(v1,u,v2))get(v2);
        idx++;
        for(int i=1;i<idx;i++)
            dis[i][idx]=dis[idx][i]=dis[u][i]+(vis[i]?-l:l);
        memset(vis+1,0,sizeof(bool)*idx);
        vector<int>x;
        for(auto v2:e[u])
            if(!in(v1,u,v2))e[idx].push_back(v2);
            else x.push_back(v2);
        return e[u]=x,e[u].push_back(idx),1;
    }
    return 0;
}
void dfs1(int u){
    while(work(u));
    for(auto v:e[u])
        dfs1(v),g[v].push_back(u);
}
void dfs2(int u,int ft){
    siz[u]=u<=m?1:0;
    for(auto v:e[u])
        if(v!=ft)dfs2(v,u),siz[u]+=siz[v],maxsiz[u]=max(maxsiz[u],siz[v]);
    maxsiz[u]=max(maxsiz[u],m-siz[u]);
}
void dfs3(int u,int ft){
    int sn=0;
    for(auto v:e[u])
        if(v!=ft&&siz[v]>siz[sn])sn=v;
    if(sn)dfs3(sn,u);
    else dv=du?p[u]:0,du=du?du:p[u];
}
signed main(){
    n=read(),m=min(n,S),idx=m;
    for(int i=1;i<=n;i++)
        p[i]=i,swap(p[i],p[rad(1,i)]);
    for(int i=1;i<=m;i++)
        for(int j=i+1;j<=m;j++)
            printf("? %d %d\n",p[i],p[j]);
    fflush(stdout);
    for(int i=1;i<=m;i++)
        for(int j=i+1;j<=m;j++)
            dis[i][j]=dis[j][i]=read();
    for(int i=2;i<=m;i++){
        int maxn=0,maxid=1;
        for(int j=2;j<=m;j++)
            if(i!=j&&in(1,j,i)&&dis[1][j]>maxn)maxn=dis[1][j],maxid=j;
        e[maxid].push_back(i);
    }
    dfs1(1);
    for(int i=1;i<=idx;i++)
        for(auto j:g[i])
            e[i].push_back(j);
    dfs2(1,0);
    int minn=m+1,minid=0;
    for(int i=1;i<=idx;i++)
        if(maxsiz[i]<minn)minn=maxsiz[i],minid=i;
    dfs2(minid,0);
    sort(e[minid].begin(),e[minid].end(),[&](int x,int y){return siz[x]>siz[y];});
    dfs3(e[minid][0],minid),dfs3(e[minid][1],minid);
    for(int i=1;i<=n;i++){
        if(du!=i)printf("? %d %d\n",du,i);
        if(dv!=i)printf("? %d %d\n",dv,i);
    }
    fflush(stdout);
    for(int i=1;i<=n;i++)
        d1[i]=du==i?0:read(),d2[i]=dv==i?0:read();
    int len=d1[dv];
    for(int i=1;i<=n;i++)
        if(d1[i]+d2[i]==len)r[d1[i]]=i;
    for(int i=1;i<=n;i++)
        w[(d1[i]-d2[i]+len)>>1]++;
    for(int i=0,now=0;;i++)
        now+=w[i],now*2>n?report(r[i]):void();
    return 0;
}