题解 P6177 Count on a tree II/【模板】树分块

· · 题解

upd on 2022.2.23:修了个 typo。

洛谷题面传送门

好家伙,在做这道题之前我甚至不知道有个东西叫树分块

树分块,说白了就是像对序列分块一样设一个阈值 B,然后在树上随机撒 \dfrac{n}{B} 个关键点,满足任意一个点到距离其最近的关键点距离不超过 \mathcal O(B) 级别,这样我们就可以预处理关键点两两之间的信息,然后询问两个点路径上的信息时直接将预处理的信息拿出来使用,再额外加上两个端点到距离它们最近的关键点之间的路径的贡献即可算出答案,复杂度 \mathcal O(B^2+qB+\dfrac{n^2}{B}),一般 B\sqrt{n}

当然这个“随机撒点”也并不用什么高超的玄学技巧,甚至不用随机()。一个很显然的想法是将深度 \bmod B=0 设为关键点,但稍微懂点脑子即可构造出反例:一条长度为 B 的链下面挂 n-B 个叶子。因此考虑对上面的过程做点手脚,我们记 dis_ii 到距离它最的叶子节点的距离,那么我们将深度 \bmod B=0dis_i\ge B 的点 i 设为关键节点即可,容易证明在这种构造方法下关键点个数是严格 \dfrac{n}{B} 级别的,读者自证不难,复杂度也就得到了保证。

接下来考虑怎样求出答案,我们首先预处理出每堆关键节点路径上颜色的情况,用 bitset 维护,该操作显然可以以 \mathcal O(\dfrac{n^2}{B}) 的及 \mathcal O(\dfrac{n^3}{B^2\omega}) 空间复杂度完成。那么对于一组询问 x,y,记 x 祖先中离其最近的关键点为 fxfy 的定义类似,l=\text{LCA}(x,y),分三种情况:

时间复杂度 \mathcal O(\dfrac{n^2}{B}+qB+\dfrac{n^3}{B^2\omega})=\mathcal O(n\sqrt{n}+\dfrac{n^2}{\omega}),常数略有点大,但实际跑起来不算慢(

const int MAXN=4e4;
const int LOG_N=16;
const int BLK=200;
int n,qu,a[MAXN+5],key[MAXN+5],uni[MAXN+5],num=0;
int hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int id[MAXN+5],pcnt=0,blk,buc[MAXN+5];
int dis[MAXN+5],dep[MAXN+5],fa[MAXN+5][LOG_N+2];
bitset<MAXN+5> col[21000],vis;
void dfs0(int x,int f){
    fa[x][0]=f;
    for(int e=hd[x];e;e=nxt[e]){
        int y=to[e];if(y==f) continue;dep[y]=dep[x]+1;
        dfs0(y,x);chkmax(dis[x],dis[y]+1);
    } if(dep[x]%blk==0&&dis[x]>=blk) id[x]=++pcnt;
}
int qwq[BLK+5][BLK+5],cc=0;
void dfs(int x,int f,int rt){
    buc[a[x]]++;if(buc[a[x]]==1) vis.set(a[x]);
    if(id[x]) col[qwq[id[rt]][id[x]]]=vis;
    for(int e=hd[x];e;e=nxt[e]){
        int y=to[e];if(y==f) continue;
        dfs(y,x,rt);
    } buc[a[x]]--;if(!buc[a[x]]) vis.reset(a[x]);
}
int getlca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=LOG_N;~i;i--) if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
    if(x==y) return x;
    for(int i=LOG_N;~i;i--) if(fa[x][i]^fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int jump(int x){
    while(x){
        if(id[x]) return x;
        x=fa[x][0];
    } return 0;
}
int get_kanc(int x,int k){
    for(int i=LOG_N;~i;i--) if(k>>i&1) x=fa[x][i];
    return x;
}
int main(){
//  freopen("C:\\Users\\汤\\Downloads\\P6177_1.in","r",stdin);
//  cout<<sizeof(col)<<endl;
    scanf("%d%d",&n,&qu);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]),key[i]=a[i];
    key[0]=-1;sort(key+1,key+n+1);
    for(int i=1;i<=n;i++) if(key[i]^key[i-1]) uni[++num]=key[i];
    for(int i=1;i<=n;i++) a[i]=lower_bound(uni+1,uni+num+1,a[i])-uni;
    for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
    blk=(int)sqrt(n);dfs0(1,0);dep[0]=-1;
    for(int i=1;i<=pcnt;i++) for(int j=1;j<=i;j++) qwq[i][j]=qwq[j][i]=++cc;
    for(int i=1;i<=n;i++) if(id[i]) memset(buc,0,sizeof(buc)),dfs(i,0,i);
    for(int i=1;i<=LOG_N;i++) for(int j=1;j<=n;j++)
        fa[j][i]=fa[fa[j][i-1]][i-1];
    int pre=0;
    while(qu--){
        int x,y;scanf("%d%d",&x,&y);x^=pre;vis.reset();
        int l=getlca(x,y),fx=jump(x),fy=jump(y);
        if(dep[fx]<dep[l]&&dep[fy]<dep[l]){
            while(x!=l) vis.set(a[x]),x=fa[x][0];
            while(y!=l) vis.set(a[y]),y=fa[y][0];
            vis.set(a[l]);
            printf("%d\n",pre=vis.count());
        } else if(dep[fx]>=dep[l]&&dep[fy]>=dep[l]){
            vis=col[qwq[id[fx]][id[fy]]];
            while(x!=fx) vis.set(a[x]),x=fa[x][0];
            while(y!=fy) vis.set(a[y]),y=fa[y][0];
            printf("%d\n",pre=vis.count());
        } else{
            if(dep[fy]>=dep[l]) swap(x,y),swap(fx,fy);
            assert(dep[fy]<dep[l]);
            int z=get_kanc(x,max(dep[x]-dep[l]-(blk<<1|1),0));
            int near=-1;
            while(z!=l){
                if(id[z]) near=z;
                z=fa[z][0];
            } if(id[l]) near=l;
            assert(~near);
            if(fx!=near) vis=col[qwq[id[fx]][id[near]]];
            while(x!=fx) vis.set(a[x]),x=fa[x][0];
            while(near!=l) vis.set(a[near]),near=fa[near][0];
            while(y!=l) vis.set(a[y]),y=fa[y][0];
            vis.set(a[l]);
            printf("%d\n",pre=vis.count());
        }
    }
    return 0;
}