P6177 Count on a tree II/【模板】树分块 题解

· · 题解

写了一点神奇的东西,具体来说是 @critnos 的代码实现。

这题确实有可以跑得飞快的平方除以 \omega 算法,以及跑得更快的平方 \log 除以 \omega 算法。但是既然是树分块模板题,那么我就写了个常数超大的树分块。

1. 寻找关键点

我们随机出 \frac{\sqrt n}{2} 个点,建立出它们的虚树。这样一来,我们可以做到每个点向上跳直到跳到关键点上所需要跳的次数是期望 O(\sqrt n),关键点个数是 O(\sqrt n),并且对于每两个关键点,它们的最近公共祖先也是关键点。由于我们要所有点向上跳都有关键点,所以我们可以在一开始钦定节点 1 为关键点,后面处理的时候以 1 号节点为根。

2. 具体做法

我们在找出 O(\sqrt n) 个关键点之后,首先要记录关键点之间的答案。我们可以开一个记录颜色的桶,以每一个关键点为根都跑一次深搜,像莫队转移一样记录搜索到每个点时它到当前的根节点的路径上有多少种不同的颜色。这样总的时间复杂度是 O(n\sqrt n) 的。

接下来,我们以 1 为根开始深搜,预处理出每个点的祖先(不算它自己)中,离它最近的关键点是哪一个。同时,我们可以用记录颜色的桶记录下来从 1 号节点到每一个关键点的路径上,每种颜色出现了几次。时间复杂度 O(n\sqrt n)

然后处理完了这些,我们可以开始查询。查询主要就是暴力计算散点,对于 O(\sqrt n) 的散点,分别查询它们在中间的整块上出现了几次。这可以用树上前缀和来维护。

具体来说,每次查询,我们对于还原后的 x,y,求出它们的祖先(包括他们自己)中最近的关键点 fx,fy,以及 fx,fy 的最近公共祖先 lc。然后开始分类讨论:

  1. 其他情况。此时我们一定有 fx=lcfy=lc 恰有一个满足,由于 x,y 对称,我们不妨假设 fx=lc。那么我们可以从 fy 开始,不断往上跳最近的关键点直到再跳一次关键点的深度就会小于等于 x 的深度。假设这个关键点是 ty,那么答案就是 x\to lca(x,ty) 的散点,ty\to lca(x,ty) 的散点和 y\to fy 的散点,再加上 fy\to ty 的整块。时间复杂度 O(\sqrt n)

在跳的时候,为了保证每种颜色只被算一次,我们可以维护一个桶记录我们已经算过了哪些颜色。为了避免复杂度退化,我们清空桶的时候再重复以上的跳法就好了。

3. 代码

主要长度在于有很多重复内容并且很好理解的分类讨论。

#include<bits/stdc++.h>
using namespace std;
const int _=1e5+5,o=210;
int n,cnt,v,k,c,m,t,a[_],b[_],f[18][_],lg[_],p[_],q[_];
int dep[_],dfn[_],tot,nr[_],fa[_],h[_],s[o][_],pa[o][o];
vector<int>e[_];
int get(int x,int y){
    return dep[x]<dep[y]?x:y;
}
void dfs(int x,int fat){
    f[0][dfn[x]=++cnt]=x;
    dep[x]=dep[fa[x]=fat]+1;
    for(auto y:e[x])
        if(y!=fat)dfs(y,x);
}
void pas(int x,int fat,int rt){
    if(!h[a[x]]++)tot++;
    if(p[x])pa[rt][p[x]]=tot;
    for(auto y:e[x])
        if(y!=fat)pas(y,x,rt);
    if(!--h[a[x]])tot--;
}
void pre(int x,int nea){
    nr[x]=nea;h[a[x]]++;
    if(p[x])
        for(int i=1;i<=v;i++)s[p[x]][i]=h[i];
    for(auto y:e[x])
        if(y!=fa[x])pre(y,p[x]?x:nea);
    h[a[x]]--;
}
int lca(int x,int y){
    if(x==y)return x;
    if(dfn[x]>dfn[y])swap(x,y);
    int ln=lg[dfn[y]-dfn[x]];
    return fa[get(f[ln][dfn[x]+1],f[ln][dfn[y]-(1<<ln)+1])];
}
int main(){
    srand(time(0));
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    cin>>n>>m;c=k=sqrt(n)/2;
    for(int i=1;i<=n;i++)cin>>a[i],b[i]=a[i];
    sort(b+1,b+n+1);v=unique(b+1,b+n+1)-b-1;
    for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+v+1,a[i])-b;
    for(int i=1,x,y;i<n;i++)cin>>x>>y,e[x].push_back(y),e[y].push_back(x);
    dfs(1,0);
    for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
    for(int i=1;i<=16;i++)
        for(int j=1;j<=n-(1<<i-1);j++)f[i][j]=get(f[i-1][j],f[i-1][j+(1<<i-1)]);
    p[q[1]=1]=p[0]=1;
    for(int i=2;i<=k;i++){
        while(p[q[i]])q[i]=rand()%(n-1)+1;
        p[q[i]]=1;
    }
    sort(q+1,q+k+1,[&](int x,int y){return dfn[x]<dfn[y];});
    for(int i=1;i<k;i++)q[++c]=lca(q[i],q[i+1]);
    sort(q+1,q+c+1);c=unique(q+1,q+c+1)-q-1;
    for(int i=1;i<=c;i++)p[q[i]]=i;
    for(int i=1;i<=c;i++)pas(q[i],0,i);
    pre(1,0);
    for(int x,y,lc,fx,fy,ty,ans,la=0;m;m--){
        cin>>x>>y;x^=la;
        fx=p[x]?x:nr[x];fy=p[y]?y:nr[y];
        lc=lca(fx,fy);ans=0;
        if(lc!=fx&&lc!=fy){
            for(int i=x;i!=fx;i=fa[i])
                if(!h[a[i]]){
                    ans+=s[p[fx]][a[i]]+s[p[fy]][a[i]]+(a[i]==a[lc])==2*s[p[lc]][a[i]];
                    h[a[i]]=1;
                }
            for(int i=y;i!=fy;i=fa[i])
                if(!h[a[i]]){
                    ans+=s[p[fx]][a[i]]+s[p[fy]][a[i]]+(a[i]==a[lc])==2*s[p[lc]][a[i]];
                    h[a[i]]=1;
                }
            for(int i=x;i!=fx;i=fa[i])h[a[i]]=0;
            for(int i=y;i!=fy;i=fa[i])h[a[i]]=0;
            cout<<(la=ans+pa[p[fx]][p[fy]])<<'\n';
        }else if(fx==fy){
            lc=lca(x,y);
            for(int i=x;i!=fa[lc];i=fa[i])
                if(!h[a[i]])ans++,h[a[i]]=1;
            for(int i=y;i!=lc;i=fa[i])
                if(!h[a[i]])ans++,h[a[i]]=1;
            for(int i=x;i!=fa[lc];i=fa[i])h[a[i]]=0;
            for(int i=y;i!=lc;i=fa[i])h[a[i]]=0;
            cout<<(la=ans)<<'\n';
        }else{

            if(fy==lc)swap(x,y),swap(fx,fy);ty=fy;
            while(dep[nr[ty]]>dep[x])ty=nr[ty];
            lc=lca(x,y);
            for(int i=x;i!=fa[lc];i=fa[i])
                if(!h[a[i]]){
                    ans+=s[p[fy]][a[i]]+(a[i]==a[ty])==s[p[ty]][a[i]];
                    h[a[i]]=1;
                }
            for(int i=ty;i!=lc;i=fa[i])
                if(!h[a[i]]){
                    ans+=s[p[fy]][a[i]]+(a[i]==a[ty])==s[p[ty]][a[i]];
                    h[a[i]]=1;
                }
            for(int i=y;i!=fy;i=fa[i])
                if(!h[a[i]]){
                    ans+=s[p[fy]][a[i]]+(a[i]==a[ty])==s[p[ty]][a[i]];
                    h[a[i]]=1;
                }
            for(int i=x;i!=fa[lc];i=fa[i])h[a[i]]=0;
            for(int i=ty;i!=lc;i=fa[i])h[a[i]]=0;
            for(int i=y;i!=fy;i=fa[i])h[a[i]]=0;
            cout<<(la=ans+pa[p[fy]][p[ty]])<<'\n';
        }
    }
    return 0;
}

空间复杂度 O(n\sqrt n),时间复杂度 O((n+m)\sqrt n)