题解 [AGC034E Complete Compress]

· · 题解

\rm I. 前言

听冬令营讲课,只有看到这题之后马上想出了一个大概的解法,遂没听讲解敲代码。但是当时没过,今晚再调了一下就过了。这是我 AC 的第一道集训队作业题,故写一篇题解以纪念。

\rm II. Solution

n 的范围,像是一个 \Theta(n ^ 2) 的做法。

考虑到最后的点只有一个,我们可以枚举它。假设已经固定了节点 r 作为终点,不妨以它为根。

g(u) = \sum\limits_{v \in S_u, col_v = 1} \text{dist}(u, v),这里 S_u 表示以 u 为根的子树,那么一次操作后 g(r) 要么不变,要么减少 2。为了求得最少的次数,我们当然要选择后者。同时经过手玩可以发现,使 g(r) 不变的操作对最后的结果是没有影响的。细想一下,使 g(r) 不变的操作所选的两个点肯定在 r 的同一棵子树内,而没有合法方案的原因肯定是 g(r) 是奇数或者 r 的子树之间的黑点数量太不平衡。

所以我们考虑递归解决这个问题。设 f(u) 为以 u 为根的子树内通过一系列操作 g(u) 能够达到的最小值。由上讨论可知每次 g(u) 减小肯定是由其两棵子树的节点相抵消的,这要求 f(v) 不能太大(vu 的一个儿子)。

具体地,设 sz_u 表示以 u 为根子树内棋子的个数,那么 sz_u = [col_u = 1] + \sum sz_v,其中 vu 的儿子。的如果在操作之前 对于 \forall v \in S_u, f_v + sz_v < g_u - g_v - sz_v,这代表着任意一棵子树在一系列操作后都能不超过总和的一半,所以可以抵消至不能操作,f_u 即为 g(u) 的奇偶性。反之,f_u 即为 f_v + sz_v - (g_u - g_v - sz_v),因为 v 的子树内 g 最小时 ug 才能取到最小。

枚举 1, 2, \cdots, n 为终点,这个问题就被我们解决了, 时间复杂度 \Theta(n^2)

能不能再给力一点啊?

对于这种问题,有一种奇妙的技巧叫做换根 dp,没学过的建议出门左转换根 dp 板子题。

我们能否把它用在这个题目上呢?

对于一个节点 u,设它原来的父亲为 fa。如果以它为根,它的子树就比第一次多了一个,那就是以 fa 为根的不在 S_u 中的点组成的子树,我们只需要处理出这棵树的距离的最大和最小值。而这棵树又可以看作 S_{fa} 被挖去了以 u 为根的子树和添上了除去 S_{fa} 的其他点。所以我们还需要对于每个点 u 预处理 f_v + g_v + 2sz_v 的最大值和次大值。这样的话那么现在转移的方程已经比较显然了,只是细节由有亿点多,感兴趣的读者可以自行推导。 时间复杂度 \Theta(n)

我知道我讲的不清楚,所以读者可以结合代码理解。毕竟这个题目也是我目前为止调的最久的题目之一了。我相信只要你自己想出了这个题的 \Theta(n) 的解法,你对换根 dp 的掌握程度会上一个新台阶。

代码中的变量和数组名与题解中有亿点不同,题解中的 f 数组是代码中的 downf,题解中的 g 是代码中的 downup_u\sum\limits_{v \not\in S_u, col_v = 1}\text{dist}(u, v)f 数组变成了整棵树除去 S_u 中非 u 的节点以 u 为根时的最小距离和,Ans_u 是以 u 为根时最终的答案, sz_u 是初始时 S_u 中棋子的数量,g_u 则变成了 \sum\limits_{col_v = 1} \text{dist}(u, v) = up_u + down_u

#include <bits/stdc++.h>

using namespace std;

#define il inline
#define re register
#define Rep(i, s, e) for (re int i = s; i <= e; ++i)
#define Dep(i, s, e) for (re int i = s; i >= e; --i)
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)

const int N = 2000010;

il int read() {
    int x = 0; bool f = true; char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = false; c = getchar();}
    while (isdigit(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return f ? x : -x;
}

int n;
char s[N];
bool hav[N];

int head[N], tot;
struct node {int nxt, to;} edge[N << 1];
il void adde(int u, int v) {edge[++tot].nxt = head[u], edge[tot].to = v, head[u] = tot;}

int dep[N], down[N], now, downf[N], sz[N], nmax[N][3], g[N], up[N], ans = 0x3f3f3f3f, f[N], Ans[N];

il int upd(int u, int x) {
    if (x > nmax[u][1]) nmax[u][2] = nmax[u][1], nmax[u][1] = x;
    else if (x > nmax[u][2]) nmax[u][2] = x;
}

il int get(int u) {
    return down[u] + downf[u] + 2 * sz[u];
}

void dfs_down(int u, int fa) {
    dep[u] = dep[fa] + 1, sz[u] = hav[u]; int nowmax = 0;
    for (re int e = head[u]; e; e = edge[e].nxt) {
        int v = edge[e].to;
        if (v == fa) continue;
        dfs_down(v, u), down[u] += down[v] + sz[v], sz[u] += sz[v], upd(u, get(v));
    }
    downf[u] = nmax[u][1] > down[u] ? nmax[u][1] - down[u] : down[u] & 1;
}

void dp(int u, int fa) {
    if (u != 1) {
        up[u] = down[fa] - down[u] - sz[u] - sz[u] + up[fa] + sz[1], g[u] = up[u] + down[u];
        int nowmax = max(get(u) == nmax[fa][1] ? nmax[fa][2] : nmax[fa][1], f[fa] + up[fa]), nowg = g[fa] - down[u] - sz[u];
        f[u] = nowmax > nowg ? nowmax - nowg : nowg & 1;
        f[u] += sz[1] - sz[u];
        nowg = g[u], nowmax = max(nmax[u][1], up[u] + f[u]);
        Ans[u] = nowmax > nowg ? nowmax - nowg : nowg & 1;
    }
    for (re int e = head[u]; e; e = edge[e].nxt) {
        int v = edge[e].to;
        if (v == fa) continue;
        dp(v, u);
    }
}

int main() {
    n = read(), dep[0] = -1;
    scanf("%s", s + 1);
    Rep(i, 1, n) hav[i] = (s[i] == '1');
    Rep(i, 2, n) {
        int u = read(), v = read();
        adde(u, v), adde(v, u);
    }
    dfs_down(1, 0);
    if (!downf[1]) ans = down[1] >> 1;
    g[1] = down[1];
    dp(1, 0);
    Rep(i, 2, n) {
        if (!Ans[i]) ans = min(ans, g[i] >> 1);
    } 
    printf("%d\n", ans == 0x3f3f3f3f ? -1 : ans);
    return 0;
}