AtCoder Beginner Contest 202 - E Count Descendants 解题报告

· · 题解

写在前面

扫了一眼题解区,居然没有分块在线?这不得补上一篇?

后天 CSP,祝各位 RP+=(RP++ + ++RP)

解题思路

其实还真挺一眼的。把问题转化一下就变成了求 U_i 的子树中有多少个深度为 D_i+1 的点,注意是 D_i+1。把 dfs 序拎出来就变成了求区间内某个值的数量。于是将按 dfs 序转换后的深度序列分块,块长 \sqrt N。预处理出一个深度前缀和序列,不过记录的是前 i 个块中不同深度点的数。

m=sqrt(n); k=n/m;
for(int i=1;i*m<=n;++i){
    memset(sum1,0,sizeof sum1);
    for(int j=(i-1)*m+1;j<=i*m;++j) ++sum1[dep[j]];
    for(int j=1;j<=n;++j) sum[i][j]=sum[i-1][j]+sum1[j];
    ls[i]=(i-1)*m+1; rs[i]=i*m;
}
if(m*m!=n){
    memset(sum1,0,sizeof sum1);
    for(int j=n/m*m+1;j<=n;++j) ++sum1[dep[j]];
    ++k;
    for(int j=1;j<=n;++j) sum[k][j]=sum[k-1][j]+sum1[j];
    ls[k]=(k-1)*m+1; rs[k]=n;       //ls,rs分别是左端点和右端点
}

(为了方便理解贴份代码)

对于一个询问 (U_i,D_i),将其转化为序列上的一段询问很简单,相当于是求 [U_i,U_i+siz_{U_i}-1] 中值为 D_i 个元素个数,用之前预处理的前缀和,外加两段最长在 2\sqrt N 的暴力求解即可。

int calc(int l,int r,int c){
    int res=0;
    int L=(l+m-2)/m+1,R=r/m;
    res+=sum[R][c]-sum[L-1][c];
    for(int i=l;i<=rs[L-1];++i) if(dep[i]==c) ++res;
    for(int i=ls[R+1];i<=r;++i) if(dep[i]==c) ++res;
    return res;
}

这里有个小问题值得注意。当 N 是一个完全平方数的时候,如果询问区间在 [N,N],那么 R+1=\sqrt N +1ls_{R+1} 就会等于 0,答案就是错的。但是 AT 的数据似乎没有卡掉这个错,还是后来我自己偶然间测出来的(本来是想找之前代码为什么挂了,结果放在了第一版的 AC 代码上跑,发现把自己 HACK 了)。虽然过了,但是万一什么时候就炸了呢?所以应该把最后的块的下一个(实际上不存在的)块的左端点设大一点。

ls[k+1]=n+5;

AC 代码

#include<bits/stdc++.h>
#define ll long long
#define db double
#define ull unsigned ll
#define ldb long db
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pil pair<int,ll>
#define pdd pair<db,db>
#define F first
#define S second
#define PB push_back
#define MP make_pair
using namespace std;
const int N=2e5+5,M=505;
int n,m,k,p,q,u,d,siz[N],dep[N],id[N],cnt,sum[M][N],sum1[N],ls[M],rs[M];
vector<int> e[N];
void dfs(int cur,int fa){
    siz[cur]=1;
    id[cur]=++cnt;
    dep[id[cur]]=dep[id[fa]]+1;
    for(int to:e[cur]){
        dfs(to,cur);
        siz[cur]+=siz[to];
    }
}
int calc(int l,int r,int c){
    int res=0;
    int L=(l+m-2)/m+1,R=r/m;
    res+=sum[R][c]-sum[L-1][c];
    for(int i=l;i<=rs[L-1];++i) if(dep[i]==c) ++res;
    for(int i=ls[R+1];i<=r;++i) if(dep[i]==c) ++res;
    return res;
}
int main(){
    scanf("%d",&n);
    for(int i=2;i<=n;++i){
        scanf("%d",&p);
        e[p].PB(i);
    }
    dfs(1,0);
    m=sqrt(n); k=n/m;
    for(int i=1;i*m<=n;++i){
        memset(sum1,0,sizeof sum1);
        for(int j=(i-1)*m+1;j<=i*m;++j) ++sum1[dep[j]];
        for(int j=1;j<=n;++j) sum[i][j]=sum[i-1][j]+sum1[j];
        ls[i]=(i-1)*m+1; rs[i]=i*m;
    }
    if(m*m!=n){
        memset(sum1,0,sizeof sum1);
        for(int j=n/m*m+1;j<=n;++j) ++sum1[dep[j]];
        ++k;
        for(int j=1;j<=n;++j) sum[k][j]=sum[k-1][j]+sum1[j];
        ls[k]=(k-1)*m+1; rs[k]=n;
    }
    ls[k+1]=n+5;
    scanf("%d",&q);
    while(q--){
        scanf("%d%d",&u,&d);
        printf("%d\n",calc(id[u],id[u]+siz[u]-1,d+1));
    }
    return 0;
}
//To wish upon the satellite.

整体看来,这份代码跑的不算快(354ms),码量也有点浪费(语出某位代码长度从未上过 3000 的机房大佬),不过还算容易理解,个人感觉作为分块和树上问题的引入还是可行的。