P10180 题解

· · 题解

O(n\sqrt{q}) 做法

不难发现本题实际上是要算:

c_x 表示颜色 x 出现的次数。

那么,对于一组询问 x,y,如果我们能设计出 O(\min(c_x,c_y)) 的算法并进行记忆化,总的复杂度就不会超过 O(n\sqrt{q})

对每组询问 (x,y),若 c_x<c_y 则交换 x,y,我们将这组询问 (x,y) 挂在颜色 x 上。接下来对每种颜色 x 分别处理:考虑维护树上的点集形成的连通块,我们将每个连通块的信息放在这个连通块的根的位置,则插入一个点时只需考虑它的若干儿子,以及它的父亲处可能存在的信息合并。

考虑提前预处理出每个颜色为 x 的连通块的根,接下来按照深度从大到小依次插入所有颜色为 y 的点。这样插入一个点时,父亲处若存在信息合并,必然是完整的 x 连通块,我们可以在 O(1) 的时间内将其若干儿子的连通块的信息提到总的连通块的根上面,就在 O(1) 时间内完成了插入。

总的时间复杂度为 O(n\sqrt{q})

O(n+q) 做法

可以发现,对于给出的颜色 x_i,y_i ,若树上不存在一条边 i 满足 i 两端点的颜色分别是 x_i,y_i ,那么 x_i,y_i 的点会形成若干独立的连通块,可以预处理后直接计算。

其余的颜色对一共最多只有 n-1 种,就是每条边端点的颜色对集合。

首先把两端点颜色相同的边缩点。

枚举每一种可能的颜色对 x_i,y_i,把树上这样的边连上,计算形成的连通块的大小的平方和即可。

使用哈希表存答案,然后建边后再树上跑 DFS,可以做到 O(n+m),如果使用 std::map 或者可撤销并查集也能通过。

#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+7;
int n,m,q;
int c[N];
int idx=0;struct dsu
{
    int fa[N],siz[N];
    int find(int x)
    {
        if(x==fa[x])return x;
        return fa[x]=find(fa[x]);
    }
    void merge(int x,int y)
    {
        if(find(x)==find(y))return;
        x=find(x);y=find(y);
        fa[x]=y;
        siz[y]+=siz[x];
    }
}A,B;
struct edge 
{
    int a,b,next,id;
}e[N];
const int M = 1e6+7;
int flink[M],t=0;
int get(int a,int b)
{
    int h=(1ll*a*131%M+b)%M;
    for(int i=flink[h];i;i=e[i].next)
    if(e[i].a==a&&e[i].b==b)return e[i].id;
    e[++t].a=a;
    e[t].b=b;
    e[t].id=++idx;
    e[t].next=flink[h];
    flink[h]=t;
    return idx;
}
int qry(int a,int b)
{
    int h=(1ll*a*131%M+b)%M;
    for(int i=flink[h];i;i=e[i].next)
    if(e[i].a==a&&e[i].b==b)return e[i].id;
    return 0;
}
#define PII pair<int,int>
#define mk(x,y) make_pair(x,y)
#define X(x) x.first
#define Y(x) x.second
typedef long long LL;
inline int read() {
    char ch = getchar(); int x = 0;
    while (!isdigit(ch)) {ch = getchar();}
    while (isdigit(ch)) {x = x * 10 + ch - 48; ch = getchar();}
    return x;
}
void write(LL x) {
    if (!x) return;
    write(x / 10); putchar(x % 10 + '0');
}
inline void print(LL x, char ch = '\n') {
    if (!x) putchar('0');
    else write(x);
    putchar(ch);
}
vector<int> E[N];
LL ans[N];
int U[N],V[N];
int seq[2*N],tot=0;
bool mark[N];
LL ext[N];
int vis[N],tag;
int main()
{
    n = read(); q = read();
    for(int i=1;i<=n;i++)
    {
        c[i] = read();
        A.fa[i]=i;
        A.siz[i]=1;
    }
    for(int i=2;i<=n;i++)
    {
        int x;
        x = read();
        if(c[i]==c[x]) A.merge(x,i);
        else 
        {
            int cx=c[x],cy=c[i];
            if(cx>cy)swap(cx,cy);
            ++m;
            U[m]=x;
            V[m]=i;
            E[get(cx,cy)].push_back(m);
        }
    }
    for(int i=1;i<=m;i++)
    {
        U[i]=A.find(U[i]);
        V[i]=A.find(V[i]);
        mark[U[i]]=1;
        mark[V[i]]=1;
    }
    for(int i=1;i<=n;i++)
    if(A.find(i)==i)
    ext[c[i]]+=1ll*A.siz[i]*A.siz[i];
    for(int r=1;r<=idx;r++)
    {
        tot=0;++tag; 
        for(auto p:E[r])
        {
            int x=U[p],y=V[p];
            if(vis[x]!=tag)vis[x]=tag,seq[++tot]=x;
            if(vis[y]!=tag)vis[y]=tag,seq[++tot]=y;
        }
        LL res=0;
        for(int i=1;i<=tot;i++)
        {
            B.fa[seq[i]]=seq[i];
            B.siz[seq[i]]=A.siz[seq[i]];
            res-=1ll*A.siz[seq[i]]*A.siz[seq[i]];
        }
        for(auto p:E[r])
        {
            int x=U[p],y=V[p];
            B.merge(x,y);
        }
        for(int i=1;i<=tot;i++)
        {
            int x=seq[i];
            if(B.find(x)==x)
            res+=1ll*B.siz[x]*B.siz[x];
        }
        ans[r]=res;
    }
    while(q--)
    {
        int x,y;
        x = read(); y = read();
        assert(x != y);
        if(x>y)swap(x,y);
        LL res=ext[x]+ext[y];
        if(qry(x,y))res+=ans[qry(x,y)];
        print(res);
    }
    return 0;
}