题解 P5384 【[Cnoi2019]雪松果树】

· · 题解

Update on 2020.5.6

增加优化,保证了能过目前的所有 Hack 数据。

线段树合并是能过这个题的!(并且跑到了最优解第一页)

没错。

诶线段树合并 O(n\log n) 空间 O(n\log n) 时间怎么可能过?

下面我们一一优化。

Step 1:优化空间

首先求 k 级祖先可以离线然后 dfs,用一个栈记录当前节点的所有祖先,然后 O(n) 得到。但是线段树合并呢?

其实,假如合并完一颗子树后马上删除这棵子树的线段树,空间复杂度是 O(n)!但是,条件是:按照 size 从大到小访问所有儿子。

为什么?因为在任意时刻,真正存在的至多只有两棵线段树:

注意我们按照 size 排序了儿子,因此不会出现许多小点增加无意义的空间的情况。现在作者不能证明空间复杂度,若有能证明的 / Hack 的请私信我。

每棵线段树空间都是 O(n),因此总空间就是 O(n)

Step 2:换掉 vector

由于 vector 过于占空间,我们需要用类似于链表的方法将询问串起来。具体可以看代码,实现很简单。

Step 3:常数优化

首先快读快输。10^6 的数据,输入输出优化会有很大作用。

然后线段树的值域最大不用到 n,而是最大的深度。这里也可以稍稍优化一下。同时这个优化也使得空间大大减小,因此暂时无法找到 Hack 数据。

Step 4:代码

AC 代码如下。

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
inline char get_char() {
    static char buf[5000001],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,5000000,stdin),p1==p2)?EOF:*p1++;
}
//#define get_char getchar
inline int read() {
    char ch;
    int x=0,ff=1;
    do {
        ch=get_char();
    } while((ch<'0'||'9'<ch)&&ch!='-');
    if(ch=='-')ff=-1,ch=get_char();
    while('0'<=ch&&ch<='9') {
        x=10*x+ch-'0';
        ch=get_char();
    }
    return x*ff;
}
char buffer[7000005];
int p1=-1;
const int p2=7000000;
inline void flush() {
    fwrite(buffer,1,p1+1,stdout),p1=-1;
}
inline void putc(const char &x){
    if(p1==p2)flush();
    buffer[++p1]=x;
}
void wrtn(int x) {
    if(!x){
        putc('0');
        putc(' ');
        return ;
    } 
    static char buf[15];
    static int len=-1;
    do {
        buf[++len]=x%10+48,x/=10;
    } while(x);
    while(len>=0){
        putc(buf[len]),--len;
    }
    putc(' ');
}
int n,cnt,cntt,yyy,hh[1000005],hhh[1000005],yyyy,h[1000005],root[1000005],ans[1000005],d[1000005],q,st[4000005],top,sta[1000005],top2,size[1000005];
struct TreeNode{
    int ls,rs,sum;
}t[4000005];
struct Edge{
    int to,next;
}e[1000005];
struct Que{
    int k,id,next;
}qq[1000005],qqq[1000005];
int NewNode(){
    if(top)return st[top--];
    return ++cntt;
}
void Add_Edge(int x,int y){
    e[++cnt].to=y;
    e[cnt].next=h[x];
    h[x]=cnt;
}
void AddQ(int x,int k,int id){
    qq[++yyy]={k,id,hh[x]};
    hh[x]=yyy;
}
void Addqq(int x,int k,int id){
    qqq[++yyyy]={k,id,hhh[x]};
    hhh[x]=yyyy;
}
void Update(int &p,int l,int r,int x){
    if(!p)p=NewNode();
    if(l==r){
        t[p].sum++;
        return ;
    }
    int mid=(l+r)/2;
    if(x<=mid)Update(t[p].ls,l,mid,x);
    else Update(t[p].rs,mid+1,r,x);
}
void Merge(int &p,int q,int l,int r){
    if(!p||!q){
        p=p+q;
        return ;
    }
    if(l==r){
        t[p].sum+=t[q].sum;
        t[q].ls=t[q].rs=t[q].sum=0;
        st[++top]=q;
        return ;
    }
    int mid=(l+r)/2;
    Merge(t[p].ls,t[q].ls,l,mid);
    Merge(t[p].rs,t[q].rs,mid+1,r);
    t[q].ls=t[q].rs=t[q].sum=0;
    st[++top]=q;
}
int Query(int p,int l,int r,int x){
    if(l==r)return t[p].sum;
    int mid=(l+r)/2;
    if(x<=mid)return Query(t[p].ls,l,mid,x);
    else return Query(t[p].rs,mid+1,r,x);
}
int maxd=0;
int aa[1000005]={0};
void Solve(int now,int fa){
    sta[++top2]=now;
    for(int i=hh[now];i;i=qq[i].next){
        Que y=qq[i];
        if(top2>y.k)Addqq(sta[top2-y.k],d[now],y.id);
    }
    root[now]=++cntt;
    int from=aa[0];
    for(int i=h[now];i;i=e[i].next)aa[++aa[0]]=e[i].to;
    int to=aa[0];
    sort(aa+from+1,aa+to+1,[](int x,int y){return size[x]>size[y];});
    for(int i=from+1;i<=to;i++){
        int y=aa[i];
        //cout<<now<<' '<<y<<endl;
        Solve(y,now);
        Merge(root[now],root[y],1,maxd);
    }
    Update(root[now],1,maxd,d[now]);
    for(int i=hhh[now];i;i=qqq[i].next){
        Que y=qqq[i];
        ans[y.id]=Query(root[now],1,maxd,y.k);
    }
    top2--;
}
void DFS(int now,int dep){
    maxd=max(maxd,dep);
    d[now]=dep,size[now]=1;
    for(int i=h[now];i;i=e[i].next){
        int y=e[i].to;
        DFS(y,dep+1);
        size[now]+=size[y];
    }
}
int main(){
    n=read(),q=read();
    for(int i=2,x;i<=n;i++){
        x=read();
        Add_Edge(x,i);
    }
    DFS(1,1);
    for(int i=1,x,k;i<=q;i++){
        x=read(),k=read();
        AddQ(x,k,i);
    }
    Solve(1,0);
    for(int i=1;i<=q;i++)wrtn(max(ans[i]-1,0));
    flush();
    return 0;
}