题解:P11194 [COTS 2021] 县 Županije

· · 题解

下文把县称作颜色,县城称为代表。

首先每个颜色的点集在树上一定是联通的。若不然,不在代表 u 所在连通块的点 vP:u\to v 路径上经过其他颜色的代表一定更近。

同理只需考察相邻颜色的限制。仅考虑两代表为 x,y 的颜色 c_x,c_y 相接的点 u,v,得到 d(u,y)>d(u,x),d(v,x)>d(v,y),即 d(u,x)=d(v,y)。若满足此条件,显然任意点要跨越 (u,v) 去找代表是严格劣的,因此充要条件是对于所有相邻颜色都满足此限制。

把颜色缩出来的树建出来,直接设 f_xxc_x 代表时其子树是否可能合法。处理颜色 c 的时候考虑其颜色儿子的点的贡献,利用点分树知道该颜色哪些点满足条件即可。

这里点分树是单点加单点查,因此复杂度可以是 O(n\log n),代码里是 O(n\log^2n)

#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
int n,K,col[maxn],U[maxn],V[maxn];
vector<int> e[maxn];
namespace SLPF{
    int fa[maxn],sze[maxn],son[maxn],dep[maxn],to[maxn];
    void dfs1(int u,int f){
        fa[u]=f,sze[u]=1,dep[u]=dep[f]+1;
        for(auto v:e[u]){
            if(v==f)continue;
            dfs1(v,u);sze[u]+=sze[v];
            if(sze[v]>sze[son[u]])son[u]=v;
        }
    }
    void dfs2(int u,int t){
        to[u]=t;
        if(son[u])dfs2(son[u],t);
        for(auto v:e[u])if(v!=fa[u]&&v!=son[u])dfs2(v,v);
    }
    int LCA(int u,int v){
        while(to[u]!=to[v]){
            if(dep[to[u]]<dep[to[v]])swap(u,v);
            u=fa[to[u]];
        }
        return dep[u]<dep[v]?u:v;
    }
    void predo(){dfs1(1,0);dfs2(1,1);}
    int dis(int u,int v){return dep[u]+dep[v]-dep[LCA(u,v)]*2;}
}
using SLPF::dis,SLPF::predo;
struct BIT{
    vector<int> c;int L;
    void predo(int n){L=n;c.resize(L+1);}
    void add(int p,int v){++p;if(p<=L)c[p]+=v;}
    int Ask(int p){++p;return c[p];}
}T1[maxn],T2[maxn];
int vis[maxn];
int dfs(int u,int fa){
    if(vis[u])return 0;
    int res=1;for(auto v:e[u])if(v!=fa)res+=dfs(v,u);
    return res;
}
int Find(int u,int fa,int tot,int &zx){
    if(vis[u])return 0;
    int sum=1,mx=0;
    for(auto v:e[u]){
        if(v==fa)continue;
        int t=Find(v,u,tot,zx);
        sum+=t,mx=max(mx,t);
    }
    if(max(mx,tot-sum)<=tot/2)zx=u;
    return sum;
}
int cd[maxn],Fa[maxn];
vector<int> pt[maxn],cpt;
void DFS(int u,int fa){
    cpt.push_back(u),cd[u]=cd[fa]+1;
    for(auto v:e[u]){
        if(vis[v]||v==fa)continue;
        DFS(v,u);
    }
}
int build(int u,int f){
    Find(u,0,dfs(u,0),u);vis[u]=1;cd[u]=0;
    for(auto v:e[u]){
        if(vis[v])continue;cpt.clear();
        DFS(v,u);for(auto x:cpt)pt[u].push_back(x);
    }pt[u].push_back(u);
    int mx1=0,mx2=0;
    for(auto x:pt[u]){mx1=max(mx1,dis(x,u));if(f)mx2=max(mx2,dis(x,f));}
    T1[u].predo(mx1+1);if(f)T2[u].predo(mx2+1);
    for(auto v:e[u]){
        if(vis[v])continue;
        int x=build(v,u);Fa[x]=u;
    }
    return u;
}
void ADD(int p,int d,int v){
    T1[p].add(d,v);int u=p;
    while(Fa[p]){int k=dis(Fa[p],u);
        if(d>=k)T1[Fa[p]].add(d-k,v),T2[p].add(d-k,v);
        p=Fa[p];
    }
}
int ask(int p){
    int res=T1[p].Ask(0),u=p;
    while(Fa[p]){int d=dis(Fa[p],u);
        res+=T1[Fa[p]].Ask(d)-T2[p].Ask(d);
        p=Fa[p];
    }
    return res;
}
vector<int> colp[maxn];
vector<array<int,4> > T[maxn];
void dfsc(int u){vis[u]=1;
    colp[col[u]].push_back(u);
    for(auto v:e[u]){
        if(vis[v]||col[v]!=col[u])continue;
        dfsc(v);
    }
}
int f[maxn],res[maxn],dv[maxn];
void calc(int u,int cf){
    for(auto [U,v,pu,pv]:T[u]){
        if(v==cf)continue;
        calc(v,u);
    }int cnt=0;
    for(auto [U,v,pu,pv]:T[u]){
        if(v==cf)continue;++cnt;
        for(auto x:colp[v]){if(!f[x])continue;
            int d=dis(x,pv);if(dv[d])continue;
            ADD(pu,d,1);dv[d]=1;
        }
        for(auto x:colp[v])dv[dis(x,pv)]=0;
    }
    for(auto x:colp[u])if(ask(x)==cnt)f[x]=1;
    for(auto [U,v,pu,pv]:T[u]){
        if(v==cf)continue;++cnt;
        for(auto x:colp[v]){if(!f[x])continue;
            int d=dis(x,pv);if(dv[d])continue;
            ADD(pu,d,-1);dv[d]=1;
        }
        for(auto x:colp[v])dv[dis(x,pv)]=0;
    }
}
void dfsa(int u,int d,int crt,int cf){
    int rx=-1;
    if(d==-1){for(auto x:colp[u])if(f[x])rx=x;}
    else for(auto x:colp[u]){if(f[x]&&dis(crt,x)==d)rx=x;}
    res[u]=rx;if(rx==-1)exit(1);
    for(auto [U,v,pu,pv]:T[u])if(v!=cf)dfsa(v,dis(pu,rx),pv,u);
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    cin>>n>>K;for(int i=1;i<=n;i++)cin>>col[i];
    for(int i=1,u,v;i<n;i++)cin>>u>>v,e[u].push_back(v),e[v].push_back(u),U[i]=u,V[i]=v;
    for(int i=1;i<=n;i++){
        if(vis[i])continue;
        if(colp[col[i]].size())return cout<<"NE"<<endl,0;
        dfsc(i);
    }memset(vis,0,sizeof(vis));
    predo();build(1,0);
    for(int i=1;i<n;i++)if(col[U[i]]!=col[V[i]])T[col[U[i]]].push_back({col[U[i]],col[V[i]],U[i],V[i]}),T[col[V[i]]].push_back({col[V[i]],col[U[i]],V[i],U[i]});
    calc(1,0);
    bool flg=0;for(auto x:colp[1])if(f[x])flg=1;
    if(!flg)cout<<"NE"<<endl;
    else {
        cout<<"DA"<<endl;
        dfsa(1,-1,-1,0);
        for(int i=1;i<=K;i++)cout<<res[i]<<" ";cout<<endl;
    }
    return 0;
}