题解:AT_abc437_g [ABC437G] Colorful Christmas Tree

· · 题解

坏了,比赛当天上午我网络流的算法笔记刚过审,晚上就用上了。(忽略这一篇文章发布的日期)

这道题当然可以树形 dp,但很难构造边。

我们可以把点按照深度的奇偶性分开,这样就是一个二分图。

对于一个节点,它在三个颜色的情况下连的边会被删多少次是可以算出来的,所以在这个新图中每个点都对应了三个点,分贝是三个颜色,我们记点 i 的第 j 个颜色的次数为 a_{i,j}

那么就常规网络流操作,我们建立超级源汇点,源点向深度为奇数的点都连这个点对应的 a_{i,j},汇点同理。然后对于树上的每一条边,我们都把这条边连接的两个点任意两个不一样的颜色连边,也就是说每个树上的边会带来 6 条边。

这个时候跑最大流,那么要是最大流不是 n - 1 就无解了,否则有解。

接下来差一步构造,只需要在流完的图判断两个节点对应的边有没有流量即可。

#include<bits/stdc++.h>
using namespace std;
#define int long long
struct flowGraph
{
    struct edge {
        int v, nxt, cap, flow;
    } edges[250000]; 

    int head[222222], cnt = 0; 
    int n, S, T;
    int maxflow = 0;
    int dist[222222], cur[222222];
    void init(int _n, int _S, int _T) {

        for (int i = 0; i <= _n; i++) head[i] = -1;
        cnt = 0;
        S = _S;
        T = _T;
        n = _n;
    }
    void addedge(int u, int v, int w) {
        edges[cnt] = {v, head[u], w, 0};
        head[u] = cnt++;
        edges[cnt] = {u, head[v], 0, 0};
        head[v] = cnt++;
    }

    bool bfs() {

        for (int i = 0; i <= n; i++) dist[i] = 0;
        queue<int> q;
        dist[S] = 1;
        q.push(S);
        while (q.size()) {
            int u = q.front();
            q.pop();
            for (int i = head[u]; ~i; i = edges[i].nxt) {
                int v = edges[i].v;
                if ((!dist[v]) && (edges[i].cap > edges[i].flow)) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
        return dist[T];
    }

    int dfs(int u, int flow) {
        if ((u == T) || (!flow)) {
            return flow;
        }
        int res = 0;
        for (int &i = cur[u]; ~i; i = edges[i].nxt) {
            int v = edges[i].v, f;
            if ((dist[v] == dist[u] + 1) && (f = dfs(v, min(flow - res, edges[i].cap - edges[i].flow)))) {
                res += f;
                edges[i].flow += f;
                edges[i ^ 1].flow -= f;
                if (res == flow) {
                    return res;
                }
            }
        }
        return res;
    }
    int dinic() {
        maxflow = 0;
        while (bfs()) {
            for (int i = 0; i <= n; i++) cur[i] = head[i];
            maxflow += dfs(S, 1e18); 
        }
        return maxflow;
    }
} g;
vector<int> edges[2005];
map<char, int> mp;
int idx(int x, int c)
{
    return x * 3 + c + 2;
}
void solve()
{
    int n;
    cin >> n;
    vector<char> c(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
    }
    for (int i = 1; i <= n; i++) {
        edges[i].clear();
    }
    g.init(n * 3 + 10, 0, 1);
    vector<int> deg(n + 1, 0);
    vector<pair<int, int>> edge(n - 1);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        edges[u].push_back(v);
        edges[v].push_back(u);
        edge[i] = {u, v};
        deg[u]++;
        deg[v]++;
    }
    vector<vector<int>> a(n + 1, vector<int>(3));
    for (int i = 1; i <= n; i++) {
        a[i][mp[c[i]]] = deg[i] / 3;
        a[i][(mp[c[i]] + 1) % 3] = (deg[i] + 2) / 3;
        a[i][(mp[c[i]] + 2) % 3] = (deg[i] + 1) / 3;
        a[i].assign(a[i].size(), 0);
        for(int k = 0; k < deg[i]; k++) {
            a[i][(mp[c[i]] + k) % 3]++;
        }
    }
    vector<int> dep(n + 1, 0);
    function<void(int, int)> dfs = [&](int x, int fa) -> void {
        for (auto nex : edges[x]) {
            if (nex != fa) {
                dep[nex] = dep[x] + 1;
                dfs(nex, x);
            }
        }
    };
    dep[1] = 1;
    dfs(1, 0);
    vector<int> v1, v2; 
    for (int i = 1; i <= n; i++) {
        if (dep[i] % 2 != 0) {
            v1.push_back(i);
        } else {
            v2.push_back(i);
        }
    }
    for (int i = 0; i < v1.size(); i++) {
        for (int j = 0; j < 3; j++) {
            if (a[v1[i]][j] > 0) {
                g.addedge(0, idx(v1[i], j), a[v1[i]][j]);
            }
        }
    }
    for (int i = 0; i < v2.size(); i++) {
        for (int j = 0; j < 3; j++) {
            if (a[v2[i]][j] > 0) {
                g.addedge(idx(v2[i], j), 1, a[v2[i]][j]);
            }
        }
    }
    vector<vector<int>> eidx(n - 1, vector<int>(9));
    for (int i = 0; i < n - 1; i++) {
        int u = edge[i].first;
        int v = edge[i].second;
        if (dep[u] % 2 == 0) {
            swap(u, v);
        }
        int id = 0;
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k < 3; k++) {
                if (j == k) continue; 
                eidx[i][id] = g.cnt;
                g.addedge(idx(u, j), idx(v, k), 1);
                id++;
            }
        }
    }
    if (g.dinic() != n - 1) {
        cout << "No" << '\n';
        return;
    }
    cout << "Yes" << '\n';
    vector<int> r1(n - 1), r2(n - 1);
    for (int i = 0; i < n - 1; i++) {
        int u = edge[i].first;
        int v = edge[i].second;
        bool flag = 0;
        if (dep[u] % 2 == 0) {
            swap(u, v);
            flag = 1;
        }
        int id = 0;
        int tp1 = -1, tp2 = -1;
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k < 3; k++) {
                if (j == k) {
                    continue;
                }
                if (g.edges[eidx[i][id]].flow == 1) {
                    tp1 = j;
                    tp2 = k;
                }
                id++;
            }
        }
        if (flag) {
            r1[i] = tp2; 
            r2[i] = tp1;
        } else {
            r1[i] = tp1;
            r2[i] = tp2;
        }
    }
    vector<int> col(n + 1);
    for (int i = 1; i <= n; i++) {
        col[i] = mp[c[i]];
    }
    vector<bool> vis(n - 1, 0);
    vector<int> ans;
    for (int i = 0; i < n - 1; i++) {
        for (int j = 0; j < n - 1; j++) {
            if (vis[j]) {
                continue;
            }
            int u = edge[j].first;
            int v = edge[j].second;
            if (col[u] == r1[j] && col[v] == r2[j]) {
                ans.push_back(j + 1);
                vis[j] = 1;
                col[u] = (col[u] + 1) % 3;
                col[v] = (col[v] + 1) % 3;
                break;
            }
        }
    }
    for (int i = 0; i < ans.size(); i++) {
        cout << ans[i] << ' ';
    }
    cout << '\n';
}
signed main()
{
    mp['R'] = 0;
    mp['G'] = 1;
    mp['B'] = 2;
    int T;
    cin >> T;
    while (T--) {
        solve();
    }
    return 0;
}