题解:AT_abc437_g [ABC437G] Colorful Christmas Tree

· · 题解

更好的阅读体验

打 abc 打到的质量最高的一道题,感觉学到了很多东西。

树是一个二分图。考虑 flow。

\texttt{RGB} 映射成数字。我们把树按照深度分成左部点和右部点。

我们考虑把一个点拆成三个,左部点 u 拆成 lp_{u, 0/1/2},右部点 u 拆成 rp_{u, 0/1/2},分别表示一个点的状态。经过一个点 lp/rp_{u, 0/1/2} 的流量,表示 u 处于当前状态 0/1/2 时,被操作的次数。

然后容易发现一个点被操作的次数是恒定的,即这个点的度数。由于我们的变换规则也是确定的,所以对于一个点 u 和一个状态,位于这个状态时他会被操作多少次也是确定的。假设 u 位于 j = 0/1/2 时被操作的次数为 w_{u, j},这个 w 是很好求的。

我们需要每个状态的 w_{u, j} 都要被流满。所以我们可以通过从源点连一条边,或者连一条边去汇点来限制通过这个点的流量。我们建边:

然后我们考虑添加一条树边。我们枚举删除这条树边时左右两个点的状态(共 6 种),分别在这六种情况中间连容量为 1 的边,表示进行一次操作,导走 1 的容量。所以假设树边 (u, v)u 在左部,v 在右部,那么建边:

我们要求每条边都被删完,也就是要每个点都要流满,所以如果最后最大流不是 n-1 说明无解。

考虑构造。构造就很简单了,我们首先可以通过看残量网络上每条边有没有被流过,得知一条树边被删除的时候,两个端点分别是什么状态。我们重复 n-1 次,每次选择一个两个端点均符合状态的边删除即可,同时修改这两个点的状态。

那么这道题就做完了,复杂度 O(\sum n^2)

#include<bits/stdc++.h>
#define endl '\n'
#define N 6006
#define M 500006
using namespace std;
int n,tot,vis[N],a[N],dep[N],lp[N][3],rp[N][3],eid[N][N]; char ch[N];
vector<int> G[N];
struct E {int u,v,cu,cv;} e[N];
struct MF_Graph { //by dyc2022
    int head[N],cnt=2,s,t,now[N],dep[N];
    struct Edge {int to,next,w;} E[M];
    void init()
    {
        for(int i=1;i<=cnt;i++)E[i]={0,0,0};
        for(int i=1;i<=n*5;i++)head[i]=now[i]=dep[i]=0;
        s=t=0,cnt=2;
    }
    void addedge(int u,int v,int w){E[cnt]={v,head[u],w},head[u]=cnt++;}
    void addflow(int u,int v,int w){addedge(u,v,w),addedge(v,u,0);}
    int bfs()
    {
        queue<int> q;
        memset(dep,-1,sizeof(dep));
        dep[s]=0,now[s]=head[s],q.push(s);
        while(q.size())
        {
            int u=q.front(); q.pop();
            for(int i=head[u],v,w;i;i=E[i].next)
            {
                v=E[i].to,w=E[i].w;
                if(dep[v]==-1&&w>0)
                {
                    dep[v]=dep[u]+1,now[v]=head[v],q.push(v);
                    if(v==t)return 1;
                }
            }
        }
        return 0;
    }
    int dfs(int u,int fl)
    {
        if(u==t)return fl;
        int ret=0;
        for(int i=now[u],v,w;i&&fl>0;i=E[i].next)
        {
            v=E[i].to,w=E[i].w,now[u]=i;
            if(dep[v]==dep[u]+1&&w>0)
            {
                int tmp=dfs(v,min(fl,w)); if(!tmp)dep[v]=-1;
                E[i].w-=tmp,E[i^1].w+=tmp,fl-=tmp,ret+=tmp;
            }
        }
        return ret;
    }
    int getflow(){int ret=0; while(bfs())ret+=dfs(s,2e9); return ret;}
} mf;
void dfs(int u,int fa)
{
    dep[u]=dep[fa]+1;
    for(int v:G[u])if(v!=fa)dfs(v,u);
}
inline int trans(char c){return c=='R'?0:(c=='G'?1:2);}
void solve()
{
    scanf("%d%s",&n,ch+1),mf.init(),tot=0;
    for(int i=1;i<=n;i++)G[i].clear(),a[i]=trans(ch[i]),vis[i]=0;
    for(int i=1;i<n;i++)
        scanf("%d%d",&e[i].u,&e[i].v),
        G[e[i].u].push_back(e[i].v),G[e[i].v].push_back(e[i].u);
    dfs(1,0),mf.s=++tot,mf.t=++tot;
    for(int i=1;i<=n;i++)if(dep[i]&1)
    {
        for(int j=0;j<3;j++)lp[i][j]=++tot;
        int sz=G[i].size(),tmp=a[i],w[3]={0,0,0};
        while(sz--)w[tmp]++,tmp=(tmp+1)%3;
        for(int j=0;j<3;j++)mf.addflow(mf.s,lp[i][j],w[j]);
    } else {
        for(int j=0;j<3;j++)rp[i][j]=++tot;
        int sz=G[i].size(),tmp=a[i],w[3]={0,0,0};
        while(sz--)w[tmp]++,tmp=(tmp+1)%3;
        for(int j=0;j<3;j++)mf.addflow(rp[i][j],mf.t,w[j]);
    }
    for(int i=1;i<n;i++)
    {
        int u=e[i].u,v=e[i].v; if(dep[v]&1)swap(u,v),swap(e[i].u,e[i].v);
        for(int j=0;j<3;j++)
            for(int k=0;k<3;k++)if(j!=k)
                mf.addflow(lp[u][j],rp[v][k],1),eid[lp[u][j]][rp[v][k]]=mf.cnt-1;
    }
    int maxf=mf.getflow();
    if(maxf<n-1)return printf("No\n"),(void)0;
    printf("Yes\n");
    for(int i=1;i<n;i++)
    {
        int u=e[i].u,v=e[i].v;
        for(int j=0;j<3;j++)
            for(int k=0;k<3;k++)if(j!=k)
            {
                int x=lp[u][j],y=rp[v][k];
                if(mf.E[eid[x][y]].w)e[i].cu=j,e[i].cv=k;
            }
    } vector<int> ans;
    for(int i=1;i<n;i++)
        for(int j=1;j<n;j++)if(!vis[j])
        {
            int u=e[j].u,v=e[j].v,cu=e[j].cu,cv=e[j].cv;
            if(a[u]==cu&&a[v]==cv)
            {
                ans.push_back(j),a[u]++,a[v]++;
                a[u]%=3,a[v]%=3,vis[j]=1; break;
            }
        }
    for(int i:ans)printf("%d ",i); putchar(10);
}
main()
{
    int T; scanf("%d",&T);
    while(T--)solve();
    return 0;
}