洛谷P8127 [BalticOI 2021 Day2] The Xana coup 题解 树形DP

· · 题解

题目链接:https://www.luogu.com.cn/problem/P8127

题目大意:给定一棵包含 n 个节点的树,树上每个点都有一个权值,节点 i 的权值为 a_i(a_i \in \{0,1\})。每次可以选择树上一个点,将这个点以及与其相邻的所有点的权值取反(0 变成 11 变成 0)。问:最少需要几次操作能够让树上所有点的权值都变为 0

我的思路:

首先是树形DP,每个节点 u 对应 4 个状态:

上述所有状态(即 f_{u,0},f_{u,1},f_{u,2},f_{u,3})因为对于节点 u 的非子孙节点来说,它们是没有办法修改节点 u 的子孙节点的状态的,所以所有的 f_{u,i}(0 \le i \le 3) 对应的状态还包含的一个信息是 —— 节点 u 的所有子孙节点当前的权值都为 0

同时,这些操作都不考虑父节点的影响(因为我这里的状态都是根据子节点的状态推导当前节点的状态)。

除此之外,我用 f_{u,i} = +\infty 来表示状态 f_{u,i} 不合法。

然后就可以推导所有的状态了。

叶子节点

对于叶子节点 u 来说:

非叶子节点

对于当前节点 u 来说,只有可能操作或者不操作。

但是这里有一个需要思考的问题,就是:当前节点 u 的状态受子节点中节点的状态的影响!

影响主要在于 —— 子节点中操作过的点是奇数个还是偶数个。

所以可以额外定义四个状态:g_{i,0}, g_{i,1}, h_{i,0}, h_{i,1},这四个状态是对于当前节点 u 来说的。对于当前节点 u,设其有 m 个子节点,则:

首先:

其次,当 i \gt 0 时,(设节点 u 的第 i 个子节点为 v,则)有:

当然我们需要的只有最终计算出来的 g_{m,0}, g_{m,1}, h_{m,0}, h_{m,1} 这四个状态(其中 m 是当前节点 u 的子节点个数)。

此时再分别讨论 a_u1 还是 0 的情况。

a_u = 1 时:

a_u = 0 时:

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
const int INF = (1<<29);
int n, a[maxn], f[maxn][4], g[maxn][2], h[maxn][2];
vector<int> e[maxn];

/**
f[u][0] 点权为 0,点未操作
f[u][1] 点权为 1,点未操作
f[u][2] 点权为 0,点有操作
f[u][3] 点权为 1,点有操作
*/

void dfs(int u, int p) {
    int sz = e[u].size();
    if (sz == 1 && u != 1) {    // 叶子节点
        if (a[u] == 1) {
            f[u][0] = INF;
            f[u][1] = 0;
            f[u][2] = 1;
            f[u][3] = INF;
        }
        else {  // a[u] == 0
            f[u][0] = 0;
            f[u][1] = INF;
            f[u][2] = INF;
            f[u][3] = 1;
        }
        return;
    }
    vector<int> tmp;
    for (auto v : e[u])
        if (v != p)
            dfs(v, u), tmp.push_back(v);
    int m = tmp.size();
    g[0][0] = h[0][0] = 0;
    g[0][1] = h[0][1] = INF;
    for (int i = 1; i <= m; i++) {
        int v = tmp[i-1];
        g[i][0] = min(INF, min(g[i-1][0] + f[v][0], g[i-1][1] + f[v][2]));
        g[i][1] = min(INF, min(g[i-1][0] + f[v][2], g[i-1][1] + f[v][0]));
        h[i][0] = min(INF, min(h[i-1][0] + f[v][1], h[i-1][1] + f[v][3]));
        h[i][1] = min(INF, min(h[i-1][0] + f[v][3], h[i-1][1] + f[v][1]));
    }
    if (a[u] == 1) {
        f[u][0] = g[m][1];
        f[u][1] = g[m][0];
        f[u][2] = min(INF, 1 + h[m][0]);
        f[u][3] = min(INF, 1 + h[m][1]);
    }
    else {
        f[u][0] = g[m][0];
        f[u][1] = g[m][1];
        f[u][2] = min(INF, 1 + h[m][1]);
        f[u][3] = min(INF, 1 + h[m][0]);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%d", a+i);
    dfs(1, -1);
    int ans = min(f[1][0], f[1][2]);
    if (ans == INF) puts("impossible");
    else printf("%d\n", ans);
    return 0;
}