题解 P5903 【【模板】树上 k 级祖先】
树上 k 级祖先
P5903 树上 k 级祖先
先讲长链剖分。对于树上的节点,它所保存的与树形态有关的信息有两个,一个是子树大小,一个是向下延伸链长。回顾重链剖分,实际上就是利用了前者,那么长链剖分实际上就是利用了后者,它的“长”儿子实际上是向下延伸链长最大的儿子。
容易发现长链剖分的过程是
考虑倍增求解该问题的过程,我们希望询问做到
一个节点的
k 级祖先所在的长链长大于等于k
证明十分显然。有了这个结论后我们再来看原问题,我们设
这启发我们设计一个算法:
- 长链剖分,对于每条链的链头,设该链长为
l ,我们预处理出这条链的所有节点,以及链头向上l 个节点,容易发现最后只会记录2n 个节点,不影响O(n) 的复杂度。 - 对每个点处理倍增数组,这一步是
O(n\log n) 的。 - 求出
\text{High-bit}(i),i\in[1,n] ,这里是可以做到O(n) 的。 - 对于询问
x 的k 级祖先,先向上跳2^{\text{High-bit(k)}} 级,然后根据\text{dep} 直接判断待求节点在链头上还是下,然后直接跳。
复杂度显然是
\text{Code}
#include<bits/stdc++.h>
#define REG register
#define LL long long
#define UI unsigned int
#define MAXN 500005
using namespace std;
inline int read(){
REG int x(0);
REG char c=getchar();
while(!isdigit(c)) c=getchar();
while(isdigit(c)) x=(x*10)+(c^48),c=getchar();
return x;
}
int n,q,rt;
vector<int> NodeUp[MAXN],NodeDown[MAXN],Edge[MAXN];
int Dep[MAXN],MDep[MAXN],Son[MAXN],Top[MAXN],HighBit[MAXN];
int Fat[MAXN][21];
LL ans;
int lastans;
UI s;
UI Get(UI x){
x^=x<<13;
x^=x>>17;
x^=x<<5;
return s=x;
}
void dfs1(int now){
MDep[now]=Dep[now]=Dep[Fat[now][0]]+1;
for(auto v:Edge[now]){
Fat[v][0]=now;
for(REG int i=0;Fat[v][i];++i)
Fat[v][i+1]=Fat[Fat[v][i]][i];
dfs1(v);
if(MDep[v]>MDep[now]) MDep[now]=MDep[v],Son[now]=v;
}
}
void dfs2(int now,int top){
Top[now]=top;
if(now==top){
for(REG int i=0,f=now;i<=MDep[now]-Dep[now];++i)
NodeUp[now].push_back(f),f=Fat[f][0];
for(REG int i=0,f=now;i<=MDep[now]-Dep[now];++i)
NodeDown[now].push_back(f),f=Son[f];
}
if(Son[now]) dfs2(Son[now],top);
for(auto v:Edge[now])
if(v^Son[now]) dfs2(v,v);
}
inline int Ask(int x,int k){
if(!k) return x;
x=Fat[x][HighBit[k]],k-=(1<<HighBit[k]),k-=Dep[x]-Dep[Top[x]],x=Top[x];
return k>=0?NodeUp[x][k]:NodeDown[x][-k];
}
void Solve(){
n=read(),q=read(),s=read(),HighBit[1]=0;
for(REG int i=2;i<=n;++i)
HighBit[i]=HighBit[i>>1]+1;
for(REG int i=1;i<=n;++i)
Edge[read()].push_back(i);
rt=Edge[0][0];
dfs1(rt);
dfs2(rt,rt);
for(REG int i=1;i<=q;++i){
int x=((Get(s)^lastans)%n)+1;
int k=((Get(s)^lastans)%Dep[x]);
lastans=Ask(x,k);
ans^=1ll*i*lastans;
}
printf("%lld\n",ans);
}
int main(){
Solve();
}