P9111 [福建省队集训2019] 最大权独立集问题

· · 题解

考虑把贡献拆开,即计算每个点对答案的贡献。

设删除点 u 的时间是 t_u。对于一条连接 u,v 的边,若 t_u < t_v,我们把它定向为 u \to v,否则定向为 v \to u。容易发现,这样定向之后,d_u 会流向所有其可达的点。因此可以对题意做如下转化:给树上的每条边定向,设点 i 可以到达包括 i 在内的 k_i 个点,最大化 \sum\limits_{i=1}^n k_i \times d_i 的值。

考虑树形 DP。我们发现,对于当前 DP 的点 u 和它的一个儿子 v 而言,如果定向为 u \to v,额外的贡献是好计算的,我们只需要在状态中记录下 v 可达的点数。但定向为 v \to u 时,事情就没有那么简单了:v 额外的贡献取决于 u 能够到达多少个点,但在合并完其它子树之前,我们其实并不知道这个确切的值。

注意到,本题的时间复杂度允许我们在背包之外再乘一个 \mathcal{O}(n),我们不妨枚举这个值为 k 并将其计入状态,即假设合并完之后 u 能够到达 k 个点,用这个 k 来提前计算 v 的额外贡献。具体来说,我们设 f_{u,i,k} 为考虑完 u 的子树,u 可达 i 个点,v 的额外贡献按照 k 来计算时,子树内答案的最大值。初始值 f_{u,1,k} = d_u,转移枚举 k,考虑边 (u,v) 的方向:

转移完之后我们需要补上 u 的额外贡献,即 f_{u,i,k} + d_u \times (k-i) \to f_{u,i,k}。最后的答案即为 \max f_{1,i,i}

总时间复杂度 \mathcal{O}(n^3)。代码实现细节略有不同。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
constexpr int N = 4e2 + 5;
int n, a[N], c[N], sz[N];
LL f[N][N][N], g[N][N][N];
vector <int> e[N];
void dfs(int u) {
    sz[u] = 1;
    for (int k = 1; k <= n; k++) for (int i = 1; i <= n; i++) f[u][i][k] = -1e18;  
    for (int k = 1; k <= n; k++) f[u][1][k] = 1LL * a[u] * k;
    for (auto v : e[u]) {
        dfs(v);
        static LL tmp[N][N];
        for (int k = 1; k <= n; k++) for (int i = 1; i <= sz[u]; i++) tmp[i][k] = f[u][i][k], f[u][i][k] = -1e18;
        for (int k = 1; k <= n; k++) {
            for (int i = 1; i <= sz[u]; i++) {
                for (int j = 1; j <= sz[v]; j++) {
                    f[u][i + j][k] = max(f[u][i + j][k], tmp[i][k] + f[v][j][j]);
                }
                LL val = -1e18;
                for (int j = 1; j <= min(sz[v], n - k); j++) val = max(val, f[v][j][j + k]);
                f[u][i][k] = max(f[u][i][k], tmp[i][k] + val);  
            }
        } 
        sz[u] += sz[v];
    }
}
signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 2; i <= n; i++) {
        cin >> c[i];
        e[c[i]].push_back(i);
    }
    dfs(1);
    LL ans = -1e18;
    for (int i = 1; i <= n; i++) ans = max(ans, f[1][i][i]);
    cout << ans << "\n";
    return 0;
}