题解 P5588 【小猪佩奇爬树】

· · 题解

本题为一道基础数据结构题目,主要需要掌握的算法:倍增LCA(其他快速求)。

若还不会快速求LCA的同学请先自行百度或阅读P3379 【模板】最近公共祖先(LCA)题解。

一、预处理操作

  1. 我们先进行建树操作,根节点可随意指定,此处我指定1号点为根节点。

  2. 在读入每个节点的颜色的时候我们可以将每种颜色所包含的节点进行分类统计。

  3. 我们需要预处理出每个点的子树大小(包含本身),这个也可以在dfs建树的同时完成。

二、针对每个颜色枚举答案

首先找出该颜色共有几个节点。

  1. 若没有节点,答案为:\frac{n\times(n-1)}{2} 。易证,此处不过多赘述。

  2. 若只有1个节点,答案为:所有不是该节点子树上的节点数\times该节点子树上的节点数+该节点各个子树两两节点数的乘积。

    上面那句话可能有一些难以理解,用下面的这张图稍加解释。

此处我们要求的是黄色节点的路径个数。

不在黄色节点子树上的节点数:2

黄色节点的各个子树的节点数:3,1,1

答案为:2\times(3+1+1)+3\times1+3\times1+1\times1=17

  1. 若节点数>=2且这些节点在一条链上。

    首先,我们要了解如何判断这些节点是否在一条链上,将深度最深的点与其他各个节点以此计算出最近公共祖先,若答案皆为深度最浅的那个点,那么这些点在一条链上。

    接下来,我们需要知道如何求路径个数,答案为:不在深度最浅的点的该条链所在的子树上的点的个数\times深度最深的点的子树个数。

    这句话仍然有些难以理解,我依然使用一张图来解释一下。

此处我们要求的是红色节点的路径个数。

3,4两个点中3号点的深度较浅,4号点的深度较深,

不在深度最浅的点的该条链所在的子树上的点的个数:即1,2,3,9号4个点

深度最深的点的子树个数:即4,5,6,7,8号5个点

答案为:4\times5=20

  1. 若节点数>=2且这些节点不在一条链上。

    那么我们可以把这些节点按照深度从深到浅排序,寻找深度最深的两个不存在父子关系的节点,其他的节点分别和这两个点寻找最近公共祖先,如果该节点与这两个点中的一个存在父子关系,且该节点的深度比这两个点最近公共祖先的点的深度大,则该节点一定在这两个点的路径上,反之的必定不在。如果所有点都在这条路径上,那么答案为这两个节点的子树上的点的个数的乘积,否则答案为0。

    为了便于大家理解,我继续使用一张图来解释一下。

此处我们要求的是红色节点的路径个数。

深度最深的为6号节点,然后是4,5号节点。

首先,6号与4号寻找最近公共祖先,发现是4号,它们存在父子关系,不可以。

然后,6号与5号寻找最近公共祖先,发现是3号,6号和5号不存在父子关系,可以。

接着看4号和6号的最近公共祖先:4号,4号和5号的最近公共祖先:3号,4号与6号存在父子关系,4号的深度比2号深,4号一定在6号到5号的路径上。

6号节点的子树大小:3

5号节点的子树大小:4

答案为:3\times4=12

到此,所有的情况已经分类讨论清楚。

下面附上我的代码以及部分注释,由于月赛时间紧张,代码有点丑,请谅解。

#include<bits/stdc++.h>
using namespace std;
struct q
{
    int dep,num;
};
q w[1000010];
int x,y,n[1000010],dep[1000010],l,db[2000010],nxt[2000010],len[2000010];
int f[1000010][21],fath,o,oo,fathe[1000010],ll,h,hh;
long long a,s,sn[1000010],v[1000010],vv;
vector<int> t[1000010];
bool ff;
template <typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flg=0;
    for (;!isdigit(c);c=getchar()) if (c=='-') flg=1;
    for (;isdigit(c);c=getchar()) x=x*10+c-'0';
    if (flg) x=-x;
}
inline void write(long long x)
{
    if (x<0)
    {
        putchar('-');
        x=-x;
    }
    if (x>=10) write(x/10);
    putchar(x%10+48);
}
inline void writeln(long long x)
{
    write(x);
    puts("");
}//快读、快输
bool cmp(q p,q pp)
{
    return p.dep>pp.dep;
}//按照深度从大到小排序
void wrk(int fa,int u)
{
    dep[u]=dep[fa]+1;fathe[u]=fa;
    for (int i=0;i<=19;i++)
    f[u][i+1]=f[f[u][i]][i];
    int k=len[u];
    while (k)
    {
        if (db[k]==fa)
        {
            k=nxt[k];
            continue;
        }
        f[db[k]][0]=u;
        wrk(u,db[k]);
        sn[u]+=sn[db[k]];
        k=nxt[k];
    }
}//建树,寻找子树大小
int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for (int i=20;i>=0;i--)
    {
        if (dep[f[x][i]]>=dep[y]) x=f[x][i];
        if (x==y) return x;
    }
    for (int i=20;i>=0;i--)
    {
        if (f[x][i]!=f[y][i])
        {
            x=f[x][i];
            y=f[y][i];          
        }
    }
    return f[x][0];
}//寻找最近公共祖先
int main(){
    /*freopen(".in","r",stdin);
    freopen(".out","w",stdout);*/
    //ios::sync_with_stdio(false);
    read(a);
    for (int i=1;i<=a;i++)
    {
        read(n[i]);
        t[n[i]].push_back(i);
    }//读入颜色,每种颜色的节点统计
    for (int i=1;i<a;i++)
    {
        read(x);read(y);
        db[i*2-1]=y;
        nxt[i*2-1]=len[x];
        len[x]=i*2-1;
        db[i*2]=x;
        nxt[i*2]=len[y];
        len[y]=i*2;     
    }//连边
    for (int i=1;i<=a;i++) sn[i]=1;//子树大小至少为1
    wrk(0,1);//建树
    for (int i=1;i<=a;i++)
    {
        l=0;
        for (int j=0;j<t[i].size();j++)
        {
            l++;
            w[l].num=t[i][j];
            w[l].dep=dep[w[l].num];
        }//将该颜色所有的点记录下编号和深度
        if (l==1)
        {
            int k=len[w[l].num];ll=0;
            while (k)
            {
                if (db[k]==fathe[w[l].num])
                {
                    k=nxt[k];
                    continue;
                }
                ll++;
                v[ll]=sn[db[k]];
                k=nxt[k];
            }
            vv=(sn[w[l].num]-1)*(a-sn[w[l].num])+a-1;
            for (int j=1;j<=ll;j++)
                for (int m=j+1;m<=ll;m++)
                vv+=v[j]*v[m];
            writeln(vv);
            continue;
        }//只有1个节点的情况
        if (l==0)
        {
            writeln((a*(a-1))/2);
            continue;
        }//没有节点的情况
        sort(w+1,w+l+1,cmp);//按照深度从大到小排序
        hh=0;
        for (int j=2;j<=l;j++)
        {
            fath=lca(w[1].num,w[j].num);
            if (fath!=w[j].num) 
            {
                hh=j;
                break;
            }       
        }//判断是否为一条链,若hh==0则为一条链
        ff=0;
        if (hh!=0)
        for (int j=3;j<=l;j++)
        {
            o=lca(w[j].num,w[1].num);oo=lca(w[j].num,w[hh].num);
            if (((o==w[j].num)||(oo==w[j].num))&&(dep[w[j].num]>=dep[fath])) continue;
            ff=1;break;
        }//判断其他的点是否与这两个点存在父子关系
        if (ff==0)
        {
            if (hh==0)
            {
                int k=w[1].num;h=w[1].num;
                while (fathe[k]!=w[l].num)
                {
                    k=fathe[k];
                    h=k;
                }
                writeln(sn[w[1].num]*(a-sn[w[l].num]+sn[w[l].num]-sn[h]));              
            }//一条链的情况
            else writeln(sn[w[1].num]*sn[w[hh].num]);//普通情况     
        }
        else writeln(0);
    }
    return 0;
}

最后附上几组我月赛时的调试数据供大家使用。

输入1:
10
6 8 4 1 9 6 1 5 10 6 
2 1
3 1
4 1
5 3
6 5
7 1
8 1
9 4
10 3

输出1:
2
45
45
29
9
0
45
9
17
9

输入2:
5
1 2 3 2 4
1 2
2 3
3 4
2 5

输出2:
4
3
7
4
10

输入3:
10
9 6 7 6 7 2 9 10 2 6
1 2
2 4
4 7
1 3
3 5
5 6
5 8
1 9
1 10

输出3:
45
1
45
45
45
2
21
45
7
9

输入4:
10
7 2 8 9 9 10 1 6 9 6
1 2
1 7
2 3
2 4
2 9
4 5
5 6
5 8
5 10

输出4:
9
34
45
45
45
1
17
9
4
9