P8916题解

· · 题解

题目大意

给你一棵树,让你从下往上黑白染色,并计算这棵树的最大战力值。

对于战力值,初始为点权,如染色中相邻两点颜色相同,父亲节点的战力值需加上子节点和子节点中所有与其颜色相同结点的战力值。

题目分析

暴力很好想,枚举每种不同的染色方式,然后计算它的战力值总和。

时间复杂度 \mathcal O(2 ^ n \times n)。然后可以得到 20 分的好成绩。

看到题目中是一棵树,而且父亲节点的值是由子节点得出的,于是很好想到是一道树形 dp。

我们按照一般的树形 dp 来想,我们可以定义 f_{i, 0/1} 表示第 i 个点选或不选时的最大战力值。

然后我们来想如何转移,但是不幸的是我们转移 f_{i} 的时候需要知道他的子树中有哪些点与其颜色相同,很明显我们在 dp 状态中无法将其呈现。

于是我们换个思路来想,我们发现,对于每一次“合并”,父节点 x 的战力值需要加上子节点 y 为根的子树下所有颜色与之相同的点的战力值,我们反过来想,对于父节点 x,它的战力值在它之上每一次颜色与之相同的“合并”中都会将这个点的战力值加上一遍,于是我们可以向上 dp。

我们定义 f_{i, 0/1, j,k} 表示第 i 个点选或不选,它之上要经过 j 次颜色为 1 的合并和 k 次颜色为 0 的合并,一这个点为根的子树的最大士气和。

对于转移,我们很容易想到是由子节点的士气和加起来,对于当前点的点权,我们只需要加上这个点的点权乘上上面这个颜色的合并次数就好了(对于子节点的点权已经包含在子节点的 f 中了)。

抽象来讲,转移方程为:

\begin{cases} f_{x,0,j,k} = \sum\limits_{y\in x} \max\limits_{j = 0}^{j < n}\max\limits_{k = 0}^{k < n}(f_{y,0,j+1,k}, f_{y,1,j,k}) \\ f_{x,1,j,k} = \sum\limits_{y \in x} \max\limits_{j = 0}^{j < n}\max\limits_{k = 0}^{k < n}(f_{y,0,j,k}, f_{y,1,j,k+1}) \end{cases}

注:yx 的子节点,\sum\limits_{y \in x} 表示枚举 x 的子节点。

特别的,对于 f 数组的初值,我们将其定义为这个点的点权对于最终答案的价值,即:

\end{cases}

这样子我们就完成了对于点权的转移计算。

Q.E.D.

最终的时间复杂度是 \mathcal O(n ^ 3)。当然我们也可以通过深度来优化 i,j 的枚举值域,感兴趣的朋友可以自己去研究一下,我不在此做过多讲解。

code

#include <iostream>
#include <cstdio>
#include <vector>
#include <stack>
#include <math.h>
#include <cstring>
#define int long long
using namespace std;
const int N = 3e2 + 5;
int n, u, v, w[N], f[N][2][N][N];
vector <int> e[N];
void dfs(int x, int last)
{
    for(int i = 0;i < n;i++)
        for(int j = 0;j < n;j++)
            f[x][0][i][j] = (i + 1) * w[x], f[x][1][i][j] = (j + 1) * w[x];
    for(int i = 0;i < e[x].size();i++)
    {
        int y = e[x][i];
        if(y == last)
            continue;
        dfs(y, x);
        for(int j = 0;j < n;j++)
        {
            for(int k = 0;k < n;k++)
            {
                f[x][0][j][k] += max(f[y][0][j+1][k], f[y][1][j][k]);
                f[x][1][j][k] += max(f[y][0][j][k], f[y][1][j][k+1]);
            }
        }
    }
    return ;
}
signed main()
{
    scanf("%lld", &n);
    for(int i = 1;i < n;i++)
    {
        scanf("%lld %lld", &u, &v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for(int i = 1;i <= n;i++)
        scanf("%lld", &w[i]);
    dfs(1, 0);
    printf("%lld", max(f[1][0][0][0], f[1][1][0][0]));
    return 0;
}