CF1740H Solution

· · 题解

提供一个 O(n\log n) 的做法。

首先一个点的权值是 O(\log n) 级别的,这是因为设 f_k 表示权值为 k 的树的最小大小,那么有 f_k=\sum_{i=0}^{k-1} f_k +1,解得 f_k=2^k

那么只需要用一个数状压下儿子中存不存在权值为 k 的点,每次 lowbit 一下,便可做到 O(1) 动态维护一个点的 \text{MEX} 值。

接下来就是怎么维护权值了,考虑 DDP,树剖之后假设一个点所有轻儿子的权值都知道,那么该点权值关于该点重儿子的权值的函数 g(x) 可以快速复合。进一步地,我们可以在复合的时候维护出复合的前缀和函数,大概是 h(x)=g_1(x)+g_2(g_1(x))+g_3(g_2(g_1(x)))\dots 这种形式。

于是使用全局平衡二叉树维护信息,每次采用上面状压的技巧更新轻儿子。时间复杂度 O(n\log n),目前是 CF 最快解。

为了实现方便,代码中的 \text{MEX} 值实际上是 +1 了的。

#include <cstdio>
using namespace std;
int read(){
    char c=getchar();int x=0;
    while(c<48||c>57) c=getchar();
    do x=(x<<1)+(x<<3)+(c^48),c=getchar();
    while(c>=48&&c<=57);
    return x;
}
const int N=3000013,Bmask=4194302;
struct data{
    int val,a[2],s[2];
    int operator()(const int x)const{return a[x==val];}
    int operator[](const int x)const{return s[x==val];}
    void iden(){val=-1;s[0]=s[1]=a[0]=a[1]=0;}
}arr[N],sm[N];
data operator+(const data x,const data y){
    if(x.val==-1) return y;
    if(y.val==-1) return x;
    data res;
    res.val=x.val;
    res.a[0]=y(x.a[0]);
    res.a[1]=y(x.a[1]);
    res.s[0]=y[x.a[0]]+x.s[0];
    res.s[1]=y[x.a[1]]+x.s[1];
    return res;
}
int cnt[N][22],mex[N];
void addval(int x,int v){if(!cnt[x][v]++) mex[x]^=(1<<v);}
void delval(int x,int v){if(!--cnt[x][v]) mex[x]^=(1<<v);}
void init(int x){
    int lb=__builtin_ctz(mex[x]);
    int llb=__builtin_ctz(mex[x]^(1<<lb));
    arr[x].s[0]=arr[x].a[0]=arr[x].val=lb;
    arr[x].s[1]=arr[x].a[1]=llb;
}
int hd[N],ver[N],nxt[N],tot;
void add(int u,int v){
    nxt[++tot]=hd[u];hd[u]=tot;ver[tot]=v;
}
int ft[N],sn[N],sz[N],res;
int par[N],lc[N],rc[N];
void dfs(int u){
    sz[u]=1;
    for(int i=hd[u];i;i=nxt[i]){
        int v=ver[i];
        dfs(v);
        sz[u]+=sz[v];
        if(sz[sn[u]]<sz[v]) sn[u]=v;
    }
}
int stk[N],tp;
int build(int l,int r){
    if(l==r) return stk[l];
    int tl=l,tr=r,all=sz[stk[l]]+sz[stk[r+1]];
    while(tl<tr){
        int mid=(tl+tr)>>1;
        if(2*sz[stk[mid+1]]<all) tr=mid;
        else tl=mid+1;
    }
    int pos=stk[tl];
    if(l<tl) par[lc[pos]=build(l,tl-1)]=pos;
    if(tr<r) par[rc[pos]=build(tr+1,r)]=pos;
    return pos;
}
int constr(int u){
    for(int p=u;p;p=sn[p])
        for(int i=hd[p];i;i=nxt[i]){
            int v=ver[i];
            if(v==sn[p]) continue;
            par[constr(v)]=p;
        }
    tp=0;
    for(int p=u;p;p=sn[p]) stk[++tp]=p;
    stk[tp+1]=0;
    return build(1,tp);
}
bool nrt(int u){
    return par[u]&&(lc[par[u]]==u||rc[par[u]]==u);
}
void jump(int u){
    init(u);
    while(nrt(u)){
        sm[u]=sm[rc[u]]+arr[u]+sm[lc[u]];
        u=par[u];
    }
    if(~sm[u].val){
        if(par[u]) delval(par[u],sm[u](0));
        res-=sm[u][0];
    }
    sm[u]=sm[rc[u]]+arr[u]+sm[lc[u]];
    res+=sm[u][0];
    if(par[u]){
        addval(par[u],sm[u](0));
        jump(par[u]);
    }
}
int main(){
    int n=read()+1;
    for(int i=2;i<=n;++i) add(ft[i]=read(),i);
    for(int i=1;i<=n;++i) sm[i].iden(),arr[i].iden(),mex[i]=Bmask;
    sm[0].iden();dfs(1);constr(1);
    for(int i=1;i<=n;++i){
        jump(i);
        if(i>1) printf("%d\n",res-i);
    }
    return 0;
}