树形 DP
概念
树形 DP 是指在树形结构上的 DP 问题。
第一类
兄弟结点之间没有数量上的约束关系。
第二类
兄弟结点之间有数量上的约束关系。
从转移方向分类
-
从上往下转移,先转移,后递归。
-
从下往上转移,先递归,后转移。
一般树形 DP 模板(伪)
void dfs(int cur,int fa){
dp[cur] = 初始状态;//初始化
for(int nxt:g[cur]){
if(nxt == fa)continue;
dfs(nxt,cur);//递归
dp[cur] = ...dp[nxt]...;//转移(决策)
}
}//此为从下往上转移,如果从上往下转移,交换递归与转移顺序即可
例题
一、兄弟结点之间没有数量上的约束关系
P1122 最大子树和
题意
给定
分析
- 最终的子树形态与根节点无关,考虑以点
1 为根节点处理。 - 当根节点确定后,子树的选择影响父节点,但反过来不会,且不同大小的子树,具有相同的结构,考虑 DP。
- 状态:
dp_i 表示以点i 为根节点的子树的最大点权和。 - 答案:
\max(dp_i) 。 - 状态转移:
if(nxt !- fa){ dfs(nxt,cur); dp[cur] = max(dp[cur],dp[cur]+dp[nxt]); } - 初始状态:
dp_{cur} = a_{cur} 。#include <bits/stdc++.h> using namespace std; const int maxn = 16e3+100; vector<int>g[maxn]; int dp[maxn],ans = INT_MIN; void dfs(int cur,int fa){ for(int nxt:g[cur]){ if(nxt == fa)continue; dfs(nxt,cur); dp[cur] = max(dp[cur],dp[cur]+dp[nxt]); } ans = max(ans,dp[cur]); } int main(){ int n; cin>>n; vector<int>f(n+1); for(int i = 1;i<=n;i++)cin>>dp[i]; int x,y; for(int i = 1;i<n;i++){ cin>>x>>y; g[x].push_back(y); g[y].push_back(x); } dfs(1,0); cout<<ans; return 0; }P2016 战略游戏
题意:在一棵无根树上,每个结点都可以放士兵,其相连的边就可以被瞭望到,求所有边都被瞭望到的最少放置士兵数。
- 状态:
dp_{i,0/1} 表示以点i 为根节点的子树且点i 不放或者放士兵覆盖整棵子树的边的最少士兵数。 - 答案:
\min(dp_{root,0},dp_{root,1}) 。 - 状态转移:
dfs(nxt,cur); dp[cur][0]+=dp[nxt][1]; dp[cur][1]+=min(dp[nxt][0],dp[nxt][1]); - 初始状态:
dp_{cur,0} = 0,dp_{cur,1} = 1 ,其余状态极大值。#include <bits/stdc++.h> using namespace std; const int maxn = 2000; vector<int>g[maxn]; int dp[maxn][2]; void dfs(int cur,int fa){ dp[cur][0] = 0; dp[cur][1] = 1; for(int nxt:g[cur]){ if(nxt == fa)continue; dfs(nxt,cur); dp[cur][0] += dp[nxt][1]; dp[cur][1] += min(dp[nxt][0],dp[nxt][1]); } } int main(){ int n; cin>>n; for(int i = 1;i<=n;i++){ int x,y,k; cin>>x>>k; x++; while(k--){ cin>>y; y++; g[x].push_back(y); g[y].push_back(x); } } memset(dp,0x3f,sizeof(dp)); dfs(1,0); cout<<min(dp[1][1],dp[1][0]); return 0; }二、树形背包
P1273 有线电视网
题意:给定
n 个点的树,有边权(成本),叶子结点有点权(收益),求收益不小于成本的情况下,覆盖的最多的叶子节点数。 - 把边权当做负的收益,输入时加负号,则最终要求收益大于等于
0 。 - 状态:
dp_{i,j} 表示以点i 为根的子树覆盖j 个叶子获得的最大收益。 - 答案:
for(int i<=m;i>=0;i--){ if(dp[1][i]>=0){ cout<<i; return 0; } } - 状态转移:
if(nxt == fa)continue; dfs(nxt,cur); for(int j = min(m,siz[cur]);j>=0;j--){//优化(取min) O(n^3)->O(n^2) for(int k = 0;k<=min(j,siz[nxt]);k++){//优化同上 dp[cur][j] = max(dp[cur][j],dp[nxt][k]+dp[cur][j-k]+w); } } -
初始状态:
dp_{cur,0} = 0 ,其余极小值。对于cur > n-m ,则siz_{cur} = 1,dp_{cur,1} = a_{cur} 。#include <bits/stdc++.h> using namespace std; const int maxn = 3100; struct node{ int x,w; }; int dp[maxn][maxn],a[maxn]; vector<node>g[maxn]; int n,m; int dfs(int cur){ dp[cur][0] = 0; if(cur>n-m){ dp[cur][1] = a[cur]; return 1; } int sum = 0; for(auto x:g[cur]){ int nxt = x.x,w = x.w; int tmp = dfs(nxt); sum+=tmp; for(int i = min(m,sum);i>=0;i--){ for(int j = 0;j<=min(i,tmp);j++){ dp[cur][i] = max(dp[cur][i],dp[nxt][j]+dp[cur][i-j]-w); } } } return sum; } int main(){ cin>>n>>m; int k,f,c; memset(dp,-0x3f,sizeof(dp)); for(int i = 1;i<=n-m;i++){ cin>>k; while(k--){ cin>>f>>c; g[i].push_back({f,c}); } } for(int i = n-m+1;i<=n;i++){ cin>>a[i]; } dfs(1); for(int i = m;i>=0;i--){ if(dp[1][i]>=0){ cout<<i; return 0; } } return 0; }P3177 [HAOI2015] 树上染色
题意:给定一棵点数为
n 的树,再给定k ,要求将树上k 个点染黑,其余染白,收益为所有黑点两两之间的距离以及所有白点两两之间的距离之和,求最大化利益是多少。 - 直接枚举染色方案并计算点对的距离时间复杂度太高。
- 距离是由路径构成的,路径是由边构成的,考虑计算每条边在多少点对的路径中。
- 状态:
dp_{i,j} 表示以点i 为根的子树染j 个黑色结点的最大收益。 - 答案:
dp_{1,k} 。 - 状态转移:
for(int i = min(siz[cur],k);i>=0;i--){ for(int j = 0;j<=min(siz[nxt],i);j++){ int black = j*(k-j)*w; int white = (siz[nxt]-j)*(n-k-(siz[nxt]-j))*w; dp[cur][i] = max(dp[cur][j],dp[nxt][k]+dp[cur][i-j]+black+white); } } - 初始状态:
dp_{cur,0} = 0 。#include <bits/stdc++.h> using namespace std; #define int long long const int maxn = 2100; struct node{ int x,w; }; int n,k; int dp[maxn][maxn]; vector<node>g[maxn]; int dfs(int cur,int fa){ int sum = 1; dp[cur][1] = dp[cur][0] = 0; for(node x:g[cur]){ int nxt = x.x,w = x.w; if(nxt == fa)continue; int tmp = dfs(nxt,cur); sum+=tmp; for(int i = min(k,sum);i>=0;i--){ for(int j = 0;j<=min(i,tmp);j++){ int b = j*(k-j)*w; int wr = (tmp-j)*(n-k-(tmp-j))*w; if(dp[cur][i-j] != 0xcfcfcfcfcfcfcfcf){ dp[cur][i] = max(dp[cur][i],dp[nxt][j]+dp[cur][i-j]+b+wr); } } } } return sum; } signed main(){ cin>>n>>k; int x,y,w; memset(dp,0xcf,sizeof(dp)); for(int i = 1;i<n;i++){ cin>>x>>y>>w; g[x].push_back({y,w}); g[y].push_back({x,w}); } dfs(1,0); cout<<dp[1][k]; return 0; }