P4006 小 Y 和二叉树

· · 题解

感觉非常困难啊,为什么是蓝/yun

先考虑定根怎么做。

考虑从上往下确定树的形态。设 f_{x,i} 表示以 x 为根且去除第 i 个分支后能取到的最小节点(转移的话显然对 k_y<3y\min 即可),那么两个儿子哪个更优只需要比对应的 f 值;若只有一个儿子,将 f 值和父节点编号比较即可。

然后考虑根是什么。

显然对于 k_x<3 的节点,都存在一种树的形态使其可以被第一个取到。也就是说,字典序最小时的第一个节点就是满足 k_x<3 的最小的 x。考虑如何通过这个 x 判断根是什么。

通过以上方式,可以将根的数量控制在 1/2 个内,可以接受。

总时间复杂度 O(n)

#include<bits/stdc++.h>
#define il inline
using namespace std;
const int inf=1<<30;
const int maxn=1000010;
il int read(){
    int x=0;
    char c=getchar();
    for(;!(c>='0'&&c<='9');c=getchar());
    for(;c>='0'&&c<='9';c=getchar())
        x=(x<<1)+(x<<3)+c-'0';
    return x;
}
void chkmin(int &x,int y){if(y<x)x=y;}
int f[maxn][5];
int n,cnt,a[maxn][5],b[maxn][5];
int m[maxn],c[maxn],ans[maxn];
void renew(){
    if(ans[1]==0){
        for(int i=1;i<=n;i++) ans[i]=c[i];
        return ;
    }
    for(int i=1;i<=n;i++)
        if(c[i]>ans[i]) return ;
        else if(c[i]<ans[i]){
            for(int j=1;j<=n;j++) ans[j]=c[j];
            return ;
        }
}
int ls[maxn],rs[maxn];
void mtree(int fa,int x){
    if(m[x]==1) return ;
    if(m[x]==2){
        int d=a[x][1],t=b[x][1];
        if(d==fa) d=a[x][2],t=b[x][2];
        if(x<f[d][t]) rs[x]=d,mtree(x,d);
        else ls[x]=d,mtree(x,d);
    }
    if(m[x]==3){
        int d1=a[x][1],t1=b[x][1];
        int d2=a[x][2],t2=b[x][2];
        if(d1==fa) d1=a[x][3],t1=b[x][3];
        if(d2==fa) d2=a[x][3],t2=b[x][3];
        if(f[d1][t1]<f[d2][t2]) ls[x]=d1,rs[x]=d2,mtree(x,d1),mtree(x,d2);
        else ls[x]=d2,rs[x]=d1,mtree(x,d2),mtree(x,d1);
    }
}
void Print(int x){
    if(!x) return ;
    Print(ls[x]),c[++cnt]=x,Print(rs[x]);
}
void sol(int rt){
    //printf("rt = %d:\n",rt);
    memset(ls,0,sizeof(ls));
    memset(rs,0,sizeof(rs));
//  printf("([%d,%d],%d)\n",a[rt][1],b[rt][1],f[a[rt][1]][b[rt][1]]);
    if(m[rt]==1){
        if(rt<f[a[rt][1]][b[rt][1]]) 
            rs[rt]=a[rt][1],mtree(rt,a[rt][1]);
        else ls[rt]=a[rt][1],mtree(rt,a[rt][1]);
    }else{
        int d1=a[rt][1],t1=b[rt][1];
        int d2=a[rt][2],t2=b[rt][2];
        if(f[d1][t1]<f[d2][t2]) ls[rt]=d1,rs[rt]=d2,mtree(rt,d1),mtree(rt,d2);
        else ls[rt]=d2,rs[rt]=d1,mtree(rt,d2),mtree(rt,d1);
    }
    cnt=0,Print(rt);
    //for(int i=1;i<=n;i++)
    //  printf("(%d,%d)\n",ls[i],rs[i]);
    //for(int i=1;i<=n;i++) printf("%d ",c[i]);
    //printf("\n");
}
void dfs1(int fa,int x,int lw=0){
    int Mn=inf;
    for(int i=1;i<=m[x];i++)
        if(a[x][i]!=fa){
            dfs1(x,a[x][i],i);
            chkmin(Mn,f[a[x][i]][b[x][i]]);
        }else if(fa) b[fa][lw]=i,b[x][i]=lw;
    f[x][b[fa][lw]]=Mn;
    if(m[x]<3) chkmin(f[x][b[fa][lw]],x);
    /////////////
}
void dfs2(int fa,int x){
    for(int i=1;i<=m[x];i++)
        if(a[x][i]!=fa){
            f[x][i]=inf;
            for(int j=1;j<=m[x];j++)
                if(j!=i) chkmin(f[x][i],f[a[x][j]][b[x][j]]);
        }
    if(m[x]<3) for(int i=1;i<=m[x];i++) chkmin(f[x][i],x);
    for(int i=1;i<=m[x];i++)
        if(a[x][i]!=fa) dfs2(x,a[x][i]);
}
int main(){
    n=read(); 
    for(int i=1;i<=n;i++){
        m[i]=read();
        for(int j=1;j<=m[i];j++)
            a[i][j]=read();
        sort(a[i]+1,a[i]+1+m[i]);
    }dfs1(0,1),dfs2(0,1);
    //for(int i=1;i<=n;i++)
    //printf("id%d->(%d,%d,%d)\n",i,f[i][1],f[i][2],f[i][3]);
    for(int i=1;i<=n;i++)
        if(m[i]!=3){
            if(m[i]==1) sol(i),renew();
            int lst=i,x=i,tmp;
            if(m[x]==1) x=a[x][1];
            else{
                int d1=a[x][1],d2=a[x][2];
                int t1=b[x][1],t2=b[x][2];
                f[d1][t1]<f[d2][t2]?x=d2:x=d1;
            }
            while(1){
                tmp=x;
                if(m[x]==1||m[x]==2){
                    sol(x),renew();
                    break;
                }
                else{
                    int d1=a[x][1],d2=a[x][2];
                    int t1=b[x][1],t2=b[x][2];
                    if(d1==lst) d1=a[x][3],t1=b[x][3];
                    else if(d2==lst) d2=a[x][3],t2=b[x][3];
                    f[d1][t1]<f[d2][t2]?x=d2:x=d1;
                }lst=tmp;
            }
            break;
        }
    for(int i=1;i<=n;i++) printf("%d ",ans[i]);
    printf("\n");
    return 0;
}