P9056 题解
DaiRuiChen007 · · 题解
Problem Link
题目大意
给定一棵
n 个点的树,你可以进行5\times 10^5 次询问:每次询问给定三个点x,y,z ,如果三个点不在一条链上,返回0 ,否则返回链中间的节点。请求出树的重心。数据范围:
n\le 3\times 10^4 。
思路分析
先考虑链的情况,首先求出区间的两个端点是很容易的,维护
然后可以考虑随机二分,每次随机找一个点
事实上,随机二分的过程中,每次期望删掉
然后考虑原问题。
尝试把原问题转成链的情况,我们可以随机两个点
如果重心在这条链上,那么链上该点左侧、右侧的点的子树大小和均
依然考虑在
如果进行带权二分,即在
设随机出的点是
然后求出
这样做,询问总点数的期望为依然为
最后我们要判断求出的点是否是重心。
考虑把
为了优化常数,我们可以在递归时特殊维护两个端点的子树。
询问次数:不知道,反正能过。
时间复杂度
代码呈现
#include<bits/stdc++.h>
using namespace std;
extern "C" int ask(int x,int y,int z);
mt19937 rnd;
typedef vector<int> vi;
int n;
int chk(int rt,vi V) {
int cnt=0,col=0;
for(int i:V) {
if(!cnt) ++cnt,col=i;
else {
if(ask(rt,i,col)==rt) --cnt;
else ++cnt;
}
}
if(!cnt) return rt;
int tot=0;
for(int i:V) if(ask(rt,col,i)!=rt) ++tot;
return (tot<=n/2)?rt:-1;
}
int solve(int l,int r,vi L,vi R,vi C,vi S) {
if((int)L.size()>n/2) return chk(l,L);
if((int)R.size()>n/2) return chk(r,R);
if(l==r) return chk(l,S);
if(C.size()==2) {
vi tmp=C; tmp.push_back(r);
for(int i:R) tmp.push_back(i);
int z=chk(l,tmp);
if(~z) return z;
tmp=C,tmp.push_back(l);
for(int i:L) tmp.push_back(i);
return chk(r,tmp);
}
int k=l;
while(k==l||k==r) {
int rd=rnd()%(C.size()+S.size());
if(rd<(int)C.size()) k=C[rd];
else {
int t=S[rd-C.size()];
for(int i:C) if(ask(k,t,i)==i) k=i;
}
}
vector <int> CL,CR,SK,SL,SR;
for(int i:C) if(i!=k) {
if(ask(l,i,k)==i) CL.push_back(i);
else CR.push_back(i);
}
for(int i:S) {
if(ask(l,k,i)==k) {
if(ask(r,k,i)==k) SK.push_back(i);
else SR.push_back(i);
} else SL.push_back(i);
}
int sl=L.size()+CL.size()+SL.size(),sr=R.size()+CR.size()+SR.size();
if(sl<=n/2&&sr<=n/2) return chk(k,SK);
if(sr>n/2) swap(L,R),swap(CL,CR),swap(SL,SR),swap(l,r);
r=l;
R.push_back(k);
for(int i:SK) R.push_back(i);
for(int i:CR) R.push_back(i);
for(int i:SR) R.push_back(i);
for(int i:CL) if(i!=l&&ask(i,r,l)==r) r=i;
S.clear();
for(int i:SL) {
if(ask(l,r,i)==r) R.push_back(i);
else S.push_back(i);
}
return solve(l,r,L,R,CL,S);
}
extern "C" int centroid(int id,int N,int M) {
n=N;
if(id==1) return ask(1,2,3);
if(id==3||id==5) {
int l=1,r=2;
for(int i=3;i<=N;++i) {
int u=ask(l,r,i);
if(l==u) l=i;
if(r==u) r=i;
}
vi arr(N);
iota(arr.begin(),arr.end(),1);
nth_element(arr.begin(),arr.begin()+N/2,arr.end(),[&](int x,int y) {
return ask(x,y,r)==y;
});
return arr[N/2];
}
while(true) {
int l=rnd()%n+1,r=rnd()%n+1;
if(l==r) continue;
vi L,R,C{l,r},S;
for(int i=1;i<=n;++i) if(i!=l&&i!=r) {
int z=ask(i,l,r);
if(z==l) L.push_back(i);
if(z==r) R.push_back(i);
if(z==i) C.push_back(i);
if(z==0) S.push_back(i);
}
int k=solve(l,r,L,R,C,S);
if(~k) return k;
}
return 0;
}