题解 P6177 【Count on a tree II/【模板】树分块】

· · 题解

想学习树分块的话可以先看我的 blog link。

按我的那篇博客的法三分块以后的每个询问被拆成了两个零散块和一个整块,我们记两个零散块所构成的集合分别为 AB,整块构成的集合为 C,那么答案就为 |A\cup B\cup C||A|+|B|+|C|-|A\cap B|-|B\cap C|-|A\cap C|+|A\cap B\cap C|

然后发现 |A|,|B|,|A\cap B|,|A\cap B\cap C| 可以在处理零散块的时候直接计算,剩下的 |B\cap C|,|A\cap C|,|B\cap C| 可以 O(n\sqrt{n}) 预处理。

方法是从每个块的深度最浅的点 dfs 一遍就行了。

此外 |A\cap B\cap C| 的维护方式是在 A\cap B 中找点判断与 C 是否有交,具体的判断方法见代码。

复杂度 O((n+q)\sqrt{n})

#include<bits/stdc++.h>
using namespace std;

#define il inline
#define rg register

il int read()
{
    int re=0,k=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')k=-1;ch=getchar();}
    while(ch<='9'&&ch>='0'){re=re*10+ch-48;ch=getchar();}
    return re*k;
}

il void write(int x)
{
    if(x<0)return putchar('-'),write(-x),void();
    if(x<10)return putchar(x+48),void();
    return write(x/10),write(x%10),void();
}

const int BB=351;

int h[100005],l[100005],head[100005],tot,d[100005],fa[100005],top[100005],n,a[100005],sz[100005],son[100005],f[405][100005],p[405],ls,t[100005],c[100005],g[100005];
int fs[2000005],tt;
bool vis[100005];
int vis2[100005],vis1[100005];
struct ss{
    int node,nxt;
}e[100005];
void add(int u,int v)
{
    e[++tot].nxt=head[u];
    e[tot].node=v;
    head[u]=tot;
}
void dfs(int x,int ffa,int dp)
{
    fa[x]=ffa;
    d[x]=dp;
    sz[x]=1;

    int maxsz=-1;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int k=e[i].node;
        if(k==ffa)continue;
        if(!vis[k])
            g[k]=g[x];
        else g[k]=k;
        dfs(k,x,dp+1);
        sz[x]+=sz[k];
        if(maxsz<sz[k])maxsz=sz[k],son[x]=k;
    }
}
void ddfs(int x,int tt)
{
    top[x]=tt;
    l[x]=h[a[x]];//预处理头顶第一个和它相同的点的
    if(!son[x])return;
    h[a[x]]=x;
    ddfs(son[x],tt);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int k=e[i].node;
        if(fa[x]==k||son[x]==k)continue;
        ddfs(k,k);
    }
    h[a[x]]=l[x];
}
int lca(int x,int y)//树剖LCA
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])swap(x,y);
        x=fa[top[x]];
    }
    if(d[x]>d[y])swap(x,y);
    return x;
}

void dp(int x,int fa,int *ff,int *s,int ls,int ans)//从每个点开始进行预处理
{
    if(!x)return;
    s[++ls]=a[x];
    if(t[a[x]]&&!vis2[a[x]])ans++;
    vis2[s[ls]]++;
    if(vis[x])//预处理 C
    {
        for(int i=1;i<=ls;i++)
        {
            if(!t[s[i]])tt++;
            t[s[i]]++,vis2[s[i]]--;
        }
        ff[x]=tt;
    }
    else ff[x]=ans;//预处理 A 交 B
    for(rg int i=head[x];i;i=e[i].nxt)
    {
        if(e[i].node==fa)continue;
        if(vis[x])
        {   
            dp(e[i].node,x,ff,s+ls,0,0);
        }
        else
        {
            dp(e[i].node,x,ff,s,ls,ans);
        }
    }
    if(vis[x])
    {
        for(int i=1;i<=ls;i++)
        {
            t[s[i]]--;vis2[s[i]]++;
            if(!t[s[i]])tt--;
        }
    }
    vis2[s[ls]]--;
}
int sol(int u,int v,int typ)
{
    int lans=0;
    int x=u,y=v,LCA=lca(u,v);
    while(x!=LCA){if(!vis1[a[x]])lans++;vis1[a[x]]=1;x=fa[x];}
    while(y!=LCA){if(!typ||y!=v){if(!vis1[a[y]])lans++;vis1[a[y]]=1;}y=fa[y];}
    if(!vis1[a[LCA]]&&(!typ||LCA!=v))lans++;
    while(u!=LCA){vis1[a[u]]=0;u=fa[u];}
    while(v!=LCA){vis1[a[v]]=0;v=fa[v];}
    return lans;
}

int b[100005],cnt,q;

int main()
{
    freopen("2.in","r",stdin);
    freopen("2.out","w",stdout);
    n=read();q=read();
    for(rg int i=1;i<=n;i++)
    {
        b[++cnt]=a[i]=read();
        c[i]=i;
    }
    for(rg int i=1,u,v;i<n;i++)
    {
        u=read();v=read();
        add(u,v);add(v,u);
    }
    srand(time(0));
    sort(b+1,b+cnt+1);
    cnt=unique(b+1,b+cnt+1)-b-1;
    random_shuffle(c+2,c+n+1);
    for(rg int i=1;i<=n;i++)
    {
        if(i<=BB)p[i]=c[i],vis[c[i]]=1;
        a[i]=lower_bound(b+1,b+cnt+1,a[i])-b;
    }
    g[1]=1;
    dfs(1,0,1);
    ddfs(1,1);
    for(rg int i=1;i<=BB;i++)
    {
        dp(p[i],0,f[i],fs,0,0);
    }
    int lans=0;
    memset(vis1,0,sizeof(vis1));
    memset(vis2,0,sizeof(vis2));
    for(rg int i=1,u,v;i<=q;i++)
    {
        u=read()^lans;v=read();
        if(g[u]==g[v])//零散块
        {
            lans=0;
            int x=u,y=v,LCA=lca(u,v);
            while(x!=LCA){if(!vis1[a[x]])lans++;vis1[a[x]]=1;x=fa[x];}
            while(y!=LCA){if(!vis1[a[y]])lans++;vis1[a[y]]=1;y=fa[y];}
            if(!vis1[a[LCA]])lans++;
            while(u!=LCA){vis1[a[u]]=0;u=fa[u];}
            while(v!=LCA){vis1[a[v]]=0;v=fa[v];}
            write(lans);puts("");continue;
        }
        int LCA=lca(u,v);
        if(d[g[u]]<d[LCA]||d[g[v]]<d[LCA])//两个块之间是直系亲属关系
        {
            if(d[g[u]]<d[LCA])swap(u,v);
            int x=u,y=v,z=u,A=0,B=0,C=0,AB=0,BC=0,AC=0,ABC=0,xx=u,yy=v,ty;
            while(!vis[x]&&x!=LCA)
            {
                if(!vis1[a[x]])
                    A++;
                vis1[a[x]]=x;
                x=fa[x];
            }
            while(!vis[y]&&y!=LCA)
            {
                if(!vis2[a[y]])
                {
                    B++;
                    if(vis1[a[y]])
                    {
                        AB++;
                    }
                }
                vis2[a[y]]=y;y=fa[y];
            }
            while(d[g[fa[g[z]]]]>d[LCA])z=fa[g[z]];
            ty=z=g[z];
            z=fa[z];
            int zz=z;
            while(!vis[z]&&z!=LCA)
            {
                if(!vis2[a[z]])
                {
                    B++;
                    if(vis1[a[z]])
                    {
                        AB++;
                    }
                }
                vis2[a[z]]=z;z=fa[z];
            }
            if(!vis2[a[LCA]])
            {
                B++;
                if(vis1[a[LCA]])AB++;
            }
            vis2[a[LCA]]=LCA;
            while(!vis[yy]&&yy!=LCA)
            {
                if(vis2[a[yy]]&&vis1[a[yy]])
                {
                    if((d[l[vis1[a[yy]]]]>=d[ty]&&d[l[vis1[a[yy]]]]<=d[x]))
                        ABC++;
                }
                vis2[a[yy]]=0;yy=fa[yy];
            }
            while(!vis[zz]&&zz!=LCA)
            {
                if(vis2[a[zz]]&&vis1[a[zz]])
                {
                    if((d[l[vis1[a[zz]]]]>=d[ty]&&d[l[vis1[a[zz]]]]<=d[x]))
                        ABC++;
                }
                vis2[a[zz]]=0;zz=fa[zz];
            }
            if(vis2[a[LCA]])
            {
                if(vis1[a[LCA]])
                    if((d[l[vis1[a[LCA]]]]>=d[ty]&&d[l[vis1[a[LCA]]]]<=d[x]))
                        ABC++;
            }
            vis2[a[LCA]]=0;
            while(!vis[xx]&&xx!=LCA){vis1[a[xx]]=0;xx=fa[xx];}
            for(rg int j=1;j<=BB;j++)
                if(p[j]==x){x=j;break;}
            for(rg int j=1;j<=BB;j++)
                if(p[j]==ty){ty=j;break;}
            C=f[x][p[ty]];
            if(!vis[v])BC=f[x][v];
            if(!vis[u])AC=f[ty][u];
            lans=A+B+C-AB-BC-AC+ABC;
            write(lans);puts("");continue;
        }
        //两块之间不是直系亲属关系
        int x=u,y=v,A=0,B=0,C=0,AB=0,BC=0,AC=0,ABC=0,xx=u,yy=v;
        while(!vis[x]&&x!=LCA)
        {
            if(!vis1[a[x]])
                A++;
            vis1[a[x]]=x;
            x=fa[x];
        }
        while(!vis[y]&&y!=LCA)
        {
            if(!vis2[a[y]])
            {
                B++;
                if(vis1[a[y]])
                {
                    AB++;
                }
            }
            vis2[a[y]]=y;y=fa[y];
        }
        while(!vis[yy]&&yy!=LCA)
        {
            if(vis2[a[yy]]&&vis1[a[yy]])
            {
                if(((d[l[vis1[a[yy]]]]>=d[LCA])&&(d[l[vis1[a[yy]]]]<=d[x]))||((d[l[vis2[a[yy]]]]>=d[LCA])&&(d[l[vis2[a[yy]]]]<=d[y])))
                    ABC++;
            }
            vis2[a[yy]]=0;yy=fa[yy];
        }
        while(!vis[xx]&&xx!=LCA){vis1[a[xx]]=0;xx=fa[xx];}
        for(rg int j=1;j<=BB;j++)
            if(p[j]==x){x=j;break;}
        for(rg int j=1;j<=BB;j++)
            if(p[j]==y){y=j;break;}
        C=f[x][p[y]];
        if(!vis[v])BC=f[x][v];
        if(!vis[u])AC=f[y][u];
        lans=A+B+C-AB-BC-AC+ABC;
        write(lans);puts("");
    }
}