P9056 题解

· · 题解

Problem Link

题目大意

给定一棵 n 个点的树,你可以进行 5\times 10^5 次询问:每次询问给定三个点 x,y,z,如果三个点不在一条链上,返回 0,否则返回链中间的节点。请求出树的重心。

数据范围:n\le 3\times 10^4

思路分析

先考虑链的情况,首先求出区间的两个端点是很容易的,维护 l,r,每次加入 i 后把 l,r 变成 i,l,r 中不等于 f(i,l,r) 的两个点即可。

然后可以考虑随机二分,每次随机找一个点 u,然后 \mathcal O(n) 判断每个点在 [l,u] 还是 [u,r] 上,只需求判断 f(l,i,u)=u 是否成立,然后递归较大的一边即可。

事实上,随机二分的过程中,每次期望删掉 \dfrac 14 的点,那么询问的总点数期望为 \mathcal O(n) 级别。

然后考虑原问题。

尝试把原问题转成链的情况,我们可以随机两个点 l,r,然后在 l\to r 的链上找重心,由于重心的每个子树大小 \le\dfrac n2,因此重心在 l\to r 的概率不小于 \dfrac 12,期望随机次数为 \mathcal O(1) 级别。

如果重心在这条链上,那么链上该点左侧、右侧的点的子树大小和均 \le \dfrac n2,相当于在链上求带权中位数,那么就可以把树上问题转成链上问题。

依然考虑在 l\to r 的链上随机一个点,但此时重心不在链的中点处,因此删掉点的数量没有保证。

如果进行带权二分,即在 n 个点中随机一个,求出该点在链上哪个点的子树内,此时每次删掉的点期望就有保证了。

设随机出的点是 v,对应链上节点为 u,那么 f(u,i,v)=i 时令 u\gets i,不断进行此操作即可。

然后求出 u 在链上的左右部分 L,R,然后求 L,R,u 的子树,这些都是容易求出的,递归时判断 |L|,|R|\le\dfrac n2 是否成立即可,如果不成立则递归链为 L/R

这样做,询问总点数的期望为依然为 \mathcal O(n) 级别。

最后我们要判断求出的点是否是重心。

考虑把 u 删去后形成的若干子树,如果 u 不是重心,说明这些子树的大小中存在绝对众数,可以考虑摩尔投票法,我们只要快速判断两个点是否属于 u 的同一个子树,这是简单的。

为了优化常数,我们可以在递归时特殊维护两个端点的子树。

询问次数:不知道,反正能过。

时间复杂度 \mathcal O(n)

代码呈现

#include<bits/stdc++.h>
using namespace std;
extern "C" int ask(int x,int y,int z);
mt19937 rnd;
typedef vector<int> vi;
int n;
int chk(int rt,vi V) {
    int cnt=0,col=0;
    for(int i:V) {
        if(!cnt) ++cnt,col=i;
        else {
            if(ask(rt,i,col)==rt) --cnt;
            else ++cnt;
        }
    }
    if(!cnt) return rt;
    int tot=0;
    for(int i:V) if(ask(rt,col,i)!=rt) ++tot;
    return (tot<=n/2)?rt:-1;
}
int solve(int l,int r,vi L,vi R,vi C,vi S) {
    if((int)L.size()>n/2) return chk(l,L);
    if((int)R.size()>n/2) return chk(r,R);
    if(l==r) return chk(l,S);
    if(C.size()==2) {
        vi tmp=C; tmp.push_back(r);
        for(int i:R) tmp.push_back(i);
        int z=chk(l,tmp);
        if(~z) return z;
        tmp=C,tmp.push_back(l);
        for(int i:L) tmp.push_back(i);
        return chk(r,tmp);
    }
    int k=l;
    while(k==l||k==r) {
        int rd=rnd()%(C.size()+S.size());
        if(rd<(int)C.size()) k=C[rd];
        else {
            int t=S[rd-C.size()];
            for(int i:C) if(ask(k,t,i)==i) k=i;
        }
    }
    vector <int> CL,CR,SK,SL,SR;
    for(int i:C) if(i!=k) {
        if(ask(l,i,k)==i) CL.push_back(i);
        else CR.push_back(i);
    }
    for(int i:S) {
        if(ask(l,k,i)==k) {
            if(ask(r,k,i)==k) SK.push_back(i);
            else SR.push_back(i);
        } else SL.push_back(i);
    }
    int sl=L.size()+CL.size()+SL.size(),sr=R.size()+CR.size()+SR.size();
    if(sl<=n/2&&sr<=n/2) return chk(k,SK);
    if(sr>n/2) swap(L,R),swap(CL,CR),swap(SL,SR),swap(l,r);
    r=l;
    R.push_back(k);
    for(int i:SK) R.push_back(i);
    for(int i:CR) R.push_back(i);
    for(int i:SR) R.push_back(i);
    for(int i:CL) if(i!=l&&ask(i,r,l)==r) r=i;
    S.clear();
    for(int i:SL) {
        if(ask(l,r,i)==r) R.push_back(i);
        else S.push_back(i);
    }
    return solve(l,r,L,R,CL,S);
}
extern "C" int centroid(int id,int N,int M) {
    n=N;
    if(id==1) return ask(1,2,3);
    if(id==3||id==5) {
        int l=1,r=2;
        for(int i=3;i<=N;++i) {
            int u=ask(l,r,i);
            if(l==u) l=i;
            if(r==u) r=i;
        }
        vi arr(N);
        iota(arr.begin(),arr.end(),1);
        nth_element(arr.begin(),arr.begin()+N/2,arr.end(),[&](int x,int y) {
            return ask(x,y,r)==y;
        });
        return arr[N/2];
    }
    while(true) {
        int l=rnd()%n+1,r=rnd()%n+1;
        if(l==r) continue;
        vi L,R,C{l,r},S;
        for(int i=1;i<=n;++i) if(i!=l&&i!=r) {
            int z=ask(i,l,r);
            if(z==l) L.push_back(i);
            if(z==r) R.push_back(i);
            if(z==i) C.push_back(i);
            if(z==0) S.push_back(i);
        }
        int k=solve(l,r,L,R,C,S);
        if(~k) return k;
    }
    return 0;
}