AT_abc416_f [ABC416F] Paint Tree 2 题解

· · 题解

纪念我的第一道独自做的蓝色树 DP。

发现 K 很小,猜测复杂度大概 O(NK),这启示我们树形 DP。

为了方便状态转移,考虑将每颗子树分三种情况处理:不选根、选根但是不能连到上面、选根并能连到上面。dp_{u,0/1/2,x} 表示以 u 为根,情况为 0/1/2,选中了 x 条链的子树最大答案。

转移很暴力,设点 u 是点 v 父亲。

dp_{u,0,x}\larr\max_{y\le x,0\le t\le 2}\{dp_{u,0,x-y}+dp_{v,t,y}\}\\ dp_{u,1,x}\larr\max_{y\le x,0\le t\le 2}\{dp_{u,1,x-y}+dp_{v,t,y}\}\\ dp_{u,2,x}\larr\max_{y\le x,0\le t\le 2}\{dp_{u,2,x-y}+dp_{v,t,y}\}\\ dp_{u,2,x}\larr\max_{y\le x}\{dp_{u,0,x-y}+dp_{v,2,y}+a_u\}\\ dp_{u,1,x}\larr\max_{y\le x}\{dp_{u,2,x-y+1}+dp_{v,2,y}\}\\ dp_{u,2,x}\larr dp_{u,0,x-1}+a_u

前三个转移表示 u,v 所在链不同,第四个表示 vu 连的第一个儿子,第五个表示 vu 连的第二个儿子,最后一个表示 u 单独成链。

状态 N\times K\times 3=O(NK),转移 O(K),总复杂度 O(NK^2)。 :::success[AC 代码]{open}

#include <bits/stdc++.h>
using namespace std;
using ll= long long;
const int N=200005;
vector<int> g[N];
int n,k,a[N];
ll dp[N][3][6];
ll ry(int u,int x) {
    return max({dp[u][0][x],dp[u][1][x],dp[u][2][x]});
}
void upd(ll& x,ll y) {
    if(y>x) x=y;
}
void dfs(int u,int fa) {
    for(int i=0;i<3;i++)
        for(int j=0;j<=k;j++)
            dp[u][i][j]=-1e16;
    dp[u][0][0]=0;
    ll la[3][6];
    for(int& v: g[u]) if(v!=fa) {
        dfs(v,u);
        for(int i=0;i<3;i++)
            for(int j=0;j<=k;j++)
                la[i][j]=dp[u][i][j];
        for(int x=1;x<=k;x++) {
            for(int y=1;y<=x;y++) {
                upd(dp[u][0][x],la[0][x-y]+ry(v,y));
                upd(dp[u][2][x],la[2][x-y]+ry(v,y));
                upd(dp[u][2][x],la[0][x-y]+dp[v][2][y]+a[u]);
                upd(dp[u][1][x],la[1][x-y]+ry(v,y));
                upd(dp[u][1][x],la[2][x-y+1]+dp[v][2][y]);
            }
        }
    }
    for(int i=1;i<=k;i++)
        upd(dp[u][2][i],dp[u][0][i-1]+a[u]);
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin>>n>>k;
    for(int i=1;i<=n;i++)
        cin>>a[i];
    for(int u,v,i=1;i<n;i++) {
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    ll ans=-1e16;
    for(int i=0;i<=k;i++)
        upd(ans,ry(1,i));
    cout<<ans;
    return 0;
}

:::