题解:AT_abc437_g [ABC437G] Colorful Christmas Tree
坏了,比赛当天上午我网络流的算法笔记刚过审,晚上就用上了。(忽略这一篇文章发布的日期)
这道题当然可以树形 dp,但很难构造边。
我们可以把点按照深度的奇偶性分开,这样就是一个二分图。
对于一个节点,它在三个颜色的情况下连的边会被删多少次是可以算出来的,所以在这个新图中每个点都对应了三个点,分贝是三个颜色,我们记点
那么就常规网络流操作,我们建立超级源汇点,源点向深度为奇数的点都连这个点对应的
这个时候跑最大流,那么要是最大流不是
接下来差一步构造,只需要在流完的图判断两个节点对应的边有没有流量即可。
#include<bits/stdc++.h>
using namespace std;
#define int long long
struct flowGraph
{
struct edge {
int v, nxt, cap, flow;
} edges[250000];
int head[222222], cnt = 0;
int n, S, T;
int maxflow = 0;
int dist[222222], cur[222222];
void init(int _n, int _S, int _T) {
for (int i = 0; i <= _n; i++) head[i] = -1;
cnt = 0;
S = _S;
T = _T;
n = _n;
}
void addedge(int u, int v, int w) {
edges[cnt] = {v, head[u], w, 0};
head[u] = cnt++;
edges[cnt] = {u, head[v], 0, 0};
head[v] = cnt++;
}
bool bfs() {
for (int i = 0; i <= n; i++) dist[i] = 0;
queue<int> q;
dist[S] = 1;
q.push(S);
while (q.size()) {
int u = q.front();
q.pop();
for (int i = head[u]; ~i; i = edges[i].nxt) {
int v = edges[i].v;
if ((!dist[v]) && (edges[i].cap > edges[i].flow)) {
dist[v] = dist[u] + 1;
q.push(v);
}
}
}
return dist[T];
}
int dfs(int u, int flow) {
if ((u == T) || (!flow)) {
return flow;
}
int res = 0;
for (int &i = cur[u]; ~i; i = edges[i].nxt) {
int v = edges[i].v, f;
if ((dist[v] == dist[u] + 1) && (f = dfs(v, min(flow - res, edges[i].cap - edges[i].flow)))) {
res += f;
edges[i].flow += f;
edges[i ^ 1].flow -= f;
if (res == flow) {
return res;
}
}
}
return res;
}
int dinic() {
maxflow = 0;
while (bfs()) {
for (int i = 0; i <= n; i++) cur[i] = head[i];
maxflow += dfs(S, 1e18);
}
return maxflow;
}
} g;
vector<int> edges[2005];
map<char, int> mp;
int idx(int x, int c)
{
return x * 3 + c + 2;
}
void solve()
{
int n;
cin >> n;
vector<char> c(n + 1);
for (int i = 1; i <= n; i++) {
cin >> c[i];
}
for (int i = 1; i <= n; i++) {
edges[i].clear();
}
g.init(n * 3 + 10, 0, 1);
vector<int> deg(n + 1, 0);
vector<pair<int, int>> edge(n - 1);
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
edges[u].push_back(v);
edges[v].push_back(u);
edge[i] = {u, v};
deg[u]++;
deg[v]++;
}
vector<vector<int>> a(n + 1, vector<int>(3));
for (int i = 1; i <= n; i++) {
a[i][mp[c[i]]] = deg[i] / 3;
a[i][(mp[c[i]] + 1) % 3] = (deg[i] + 2) / 3;
a[i][(mp[c[i]] + 2) % 3] = (deg[i] + 1) / 3;
a[i].assign(a[i].size(), 0);
for(int k = 0; k < deg[i]; k++) {
a[i][(mp[c[i]] + k) % 3]++;
}
}
vector<int> dep(n + 1, 0);
function<void(int, int)> dfs = [&](int x, int fa) -> void {
for (auto nex : edges[x]) {
if (nex != fa) {
dep[nex] = dep[x] + 1;
dfs(nex, x);
}
}
};
dep[1] = 1;
dfs(1, 0);
vector<int> v1, v2;
for (int i = 1; i <= n; i++) {
if (dep[i] % 2 != 0) {
v1.push_back(i);
} else {
v2.push_back(i);
}
}
for (int i = 0; i < v1.size(); i++) {
for (int j = 0; j < 3; j++) {
if (a[v1[i]][j] > 0) {
g.addedge(0, idx(v1[i], j), a[v1[i]][j]);
}
}
}
for (int i = 0; i < v2.size(); i++) {
for (int j = 0; j < 3; j++) {
if (a[v2[i]][j] > 0) {
g.addedge(idx(v2[i], j), 1, a[v2[i]][j]);
}
}
}
vector<vector<int>> eidx(n - 1, vector<int>(9));
for (int i = 0; i < n - 1; i++) {
int u = edge[i].first;
int v = edge[i].second;
if (dep[u] % 2 == 0) {
swap(u, v);
}
int id = 0;
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
if (j == k) continue;
eidx[i][id] = g.cnt;
g.addedge(idx(u, j), idx(v, k), 1);
id++;
}
}
}
if (g.dinic() != n - 1) {
cout << "No" << '\n';
return;
}
cout << "Yes" << '\n';
vector<int> r1(n - 1), r2(n - 1);
for (int i = 0; i < n - 1; i++) {
int u = edge[i].first;
int v = edge[i].second;
bool flag = 0;
if (dep[u] % 2 == 0) {
swap(u, v);
flag = 1;
}
int id = 0;
int tp1 = -1, tp2 = -1;
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
if (j == k) {
continue;
}
if (g.edges[eidx[i][id]].flow == 1) {
tp1 = j;
tp2 = k;
}
id++;
}
}
if (flag) {
r1[i] = tp2;
r2[i] = tp1;
} else {
r1[i] = tp1;
r2[i] = tp2;
}
}
vector<int> col(n + 1);
for (int i = 1; i <= n; i++) {
col[i] = mp[c[i]];
}
vector<bool> vis(n - 1, 0);
vector<int> ans;
for (int i = 0; i < n - 1; i++) {
for (int j = 0; j < n - 1; j++) {
if (vis[j]) {
continue;
}
int u = edge[j].first;
int v = edge[j].second;
if (col[u] == r1[j] && col[v] == r2[j]) {
ans.push_back(j + 1);
vis[j] = 1;
col[u] = (col[u] + 1) % 3;
col[v] = (col[v] + 1) % 3;
break;
}
}
}
for (int i = 0; i < ans.size(); i++) {
cout << ans[i] << ' ';
}
cout << '\n';
}
signed main()
{
mp['R'] = 0;
mp['G'] = 1;
mp['B'] = 2;
int T;
cin >> T;
while (T--) {
solve();
}
return 0;
}