题解 P5903 【【模板】树上 k 级祖先】

· · 题解

树上 k 级祖先

P5903 树上 k 级祖先

先讲长链剖分。对于树上的节点,它所保存的与树形态有关的信息有两个,一个是子树大小,一个是向下延伸链长。回顾重链剖分,实际上就是利用了前者,那么长链剖分实际上就是利用了后者,它的“长”儿子实际上是向下延伸链长最大的儿子。

容易发现长链剖分的过程是 O(n) 的。我们考虑如何使用长链剖分解决树上 k 级祖先问题。

考虑倍增求解该问题的过程,我们希望询问做到 O(1) 而非 O(\log n)。倍增将时间浪费在了跳的过程中。有没有这样一种方法,我们先利用倍增思想跳出最开始那最大的一步,后面的一步搞定呢?我们来看长链剖分后的一条性质:

一个节点的 k 级祖先所在的长链长大于等于 k

证明十分显然。有了这个结论后我们再来看原问题,我们设 k 的二进制最高位为 \log_2h,那么我们先跳这一步,然后将问题转换为求 k-h 级祖先。不难发现,跳完之后的当前节点所在的长链长一定大于 h,因为它是初始节点的 h 级祖先。那么进一步地,由于 h>k-h,跳完之后的当前节点所在的长链长一定大于 k-h

这启发我们设计一个算法:

  1. 长链剖分,对于每条链的链头,设该链长为 l,我们预处理出这条链的所有节点,以及链头向上 l 个节点,容易发现最后只会记录 2n 个节点,不影响 O(n) 的复杂度。
  2. 对每个点处理倍增数组,这一步是 O(n\log n) 的。
  3. 求出 \text{High-bit}(i),i\in[1,n],这里是可以做到 O(n) 的。
  4. 对于询问 xk 级祖先,先向上跳 2^{\text{High-bit(k)}} 级,然后根据 \text{dep} 直接判断待求节点在链头上还是下,然后直接跳。

复杂度显然是 O(n\log n+q) 的。

\text{Code}
#include<bits/stdc++.h>
#define REG register
#define LL long long
#define UI unsigned int
#define MAXN 500005
using namespace std;
inline int read(){
    REG int x(0);
    REG char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) x=(x*10)+(c^48),c=getchar();
    return x;
}

int n,q,rt;
vector<int> NodeUp[MAXN],NodeDown[MAXN],Edge[MAXN];
int Dep[MAXN],MDep[MAXN],Son[MAXN],Top[MAXN],HighBit[MAXN];
int Fat[MAXN][21];
LL ans;
int lastans;

UI s;
UI Get(UI x){
    x^=x<<13;
    x^=x>>17;
    x^=x<<5;
    return s=x;
}

void dfs1(int now){
    MDep[now]=Dep[now]=Dep[Fat[now][0]]+1;
    for(auto v:Edge[now]){
        Fat[v][0]=now;
        for(REG int i=0;Fat[v][i];++i)
            Fat[v][i+1]=Fat[Fat[v][i]][i];
        dfs1(v);
        if(MDep[v]>MDep[now]) MDep[now]=MDep[v],Son[now]=v;
    }
}

void dfs2(int now,int top){
    Top[now]=top;
    if(now==top){
        for(REG int i=0,f=now;i<=MDep[now]-Dep[now];++i)
            NodeUp[now].push_back(f),f=Fat[f][0];
        for(REG int i=0,f=now;i<=MDep[now]-Dep[now];++i)
            NodeDown[now].push_back(f),f=Son[f];
    }
    if(Son[now]) dfs2(Son[now],top);
    for(auto v:Edge[now])
        if(v^Son[now]) dfs2(v,v);
}

inline int Ask(int x,int k){
    if(!k) return x;
    x=Fat[x][HighBit[k]],k-=(1<<HighBit[k]),k-=Dep[x]-Dep[Top[x]],x=Top[x];
    return k>=0?NodeUp[x][k]:NodeDown[x][-k];
}

void Solve(){
    n=read(),q=read(),s=read(),HighBit[1]=0;
    for(REG int i=2;i<=n;++i)
        HighBit[i]=HighBit[i>>1]+1;
    for(REG int i=1;i<=n;++i)
        Edge[read()].push_back(i);
    rt=Edge[0][0];
    dfs1(rt);
    dfs2(rt,rt);
    for(REG int i=1;i<=q;++i){
        int x=((Get(s)^lastans)%n)+1;
        int k=((Get(s)^lastans)%Dep[x]);
        lastans=Ask(x,k);
        ans^=1ll*i*lastans;
    }
    printf("%lld\n",ans);
}

int main(){
    Solve();
}