题解 P4037 【[JSOI2008]魔兽地图】

· · 题解

这道题其实很有难度,是我目前遇到最难的一道树上背包了。
首先认真审题(从这里就开始犯低级错误),发现题目已经告诉我们低级合成高级的路径是一条树。
树上,并且是个背包,这不就是选课这一类树上背包这种题吗?
我们直接把每种合成路径建成一棵树,再把没有合成路径的基本装备当成一棵树,最后都把它们往 0 节点上面连跑树DP。(这都是套路)
也就是说我们可以用选课这道题的思路来解决这一道题,只不过操作起来确实麻烦很多。
首先是预处理,比较简单。你现在看自己的儿子最多能选几个,那么你需要几个该儿子就直接把它最多能选几个除以当前你需要的个数然后所有值取min就好了。
还可以维护当前这个装备需要的花费,放一下代码:

void dfs(int x)
{
    if(!head[x])
    {
        temp[x] = min(temp[x] , m / cost[x]);//temp就是当前装备最多能选几个
        return;
    }
    temp[x] = Inf;
    for(int e = head[x] ; e ; e = edge[e].next)
    {
        int to = edge[e].to;
        dfs(to);
        temp[x] = min(temp[x] , temp[to] / edge[e].num);
        cost[x] += edge[e].num * cost[to];//cost就是当前装备的花费
    }
    temp[x] = min(temp[x] , m / cost[x]);
}

到这里我就开始犯错了,这里首先讲一下我的错误想法:

dp[i][j]表示选到第i个用品此时背包空间大小为j

然后就直接树上DP
然后我就爆了。
很明显我们这么进行状态定义是完全不对的,我们的基本装备是个数量有限的背包啊,那么我们怎么知道我们当前每种一共选了多少个基本装备呢?我们怎么知道当前的转移合不合法?
思来想去完全搞不出来,最后给自己的计时超时了就瞟了瞟题解才知道了正确的DP状态:

dp[i][j][k]表示第i个装备,其中用j个来合成父亲的装备,背包容量为k时所能获得最大收益。

于是这样我们就可以计算了,首先需要分情况讨论当前的点是不是叶子节点,如果是的话我们只能先用一个背包初始化才行:

if(!head[x])
    {
        for(int i = temp[x] ; i >= 0 ; i --)
            for(int j = i ; j <= temp[x] ; j ++)
                dp[x][i][j * cost[x]] = w[x] * (j - i);//有数量限制的dp的模板,在这里不做过多解释
        return;
    }

接着我们就可以跑DP了,第一层我们需要枚举我们选的当前装备个数,第二层就枚举儿子节点,第三四层正常背包就好了。
最后拿出来再处理一遍前面的背包就好了。
讲到这里也就差不多了,如果你还是写挂了可能需要注意一下你的dp数组是否先置为 -Inf ,再其次就是你的DP是否在中间打挂了,不得不说这道题还是有点考细节。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector> 
using namespace std;
const int Len = 105,Inf = 0x3f;
bool flag[Len],flaw[Len];
long long n,m,w[Len],cost[Len],temp[Len],dp[55][105][2005],f[2005];
int head[Len],cnt,fa[Len],root[Len],len;
struct node
{
    int next,to,num;
}edge[Len << 1];
char s[2];
void add(int from,int to,int num)
{
    edge[++ cnt].to = to;
    edge[cnt].next = head[from];
    edge[cnt].num = num;
    head[from] = cnt;
}
void dfs(int x)
{
    if(!head[x])
    {
        temp[x] = min(temp[x] , m / cost[x]);
        return;
    }
    temp[x] = Inf;
    for(int e = head[x] ; e ; e = edge[e].next)
    {
        int to = edge[e].to;
        dfs(to);
        temp[x] = min(temp[x] , temp[to] / edge[e].num);
        cost[x] += edge[e].num * cost[to];
    }
    temp[x] = min(temp[x] , m / cost[x]);
}
void DP(int x)
{
    if(!x)
    {
        for(int e = head[x] ; e ; e = edge[e].next) 
        {
            DP(edge[e].to);
            for(int i = m ; i >= 0 ; i --)
                for(int j = 0 ; j <= i ; j ++)  
                    dp[0][0][i] = max(dp[0][0][i] , dp[0][0][i - j] + dp[edge[e].to][0][j]);
        }
        return;
    }
    if(!head[x])
    {
        for(int i = temp[x] ; i >= 0 ; i --)
            for(int j = i ; j <= temp[x] ; j ++)
                dp[x][i][j * cost[x]] = w[x] * (j - i);
        return;
    }
    for(int e = head[x] ; e ; e = edge[e].next) DP(edge[e].to);
    for(int i = temp[x] ; i >= 0 ; i --)
    {
        memset(f , -Inf , sizeof f);
        f[0] = 0;
        for(int e = head[x] ; e ; e = edge[e].next)
        {
            int to = edge[e].to;
            for(int k = m ; k >= 0 ; k --)
            {
                long long t = -1e9;
                for(int p = 0 ; p <= k ; p ++)
                    t = max(t , f[k - p] + dp[to][i * edge[e].num][p]);
                f[k] = t;
            }
        }
        for(int j = 0 ; j <= i ; j ++)
            for(int k = 0 ; k <= m ; k ++) 
                dp[x][j][k] = max(dp[x][j][k] , f[k] + w[x] * (i - j));
    }
}
int main()
{
    memset(dp , -Inf , sizeof dp);
    scanf("%lld %lld",&n,&m);
    for(int i = 0 ; i <= m ; i ++) dp[0][0][i] = 0;
    for(int i = 1 ; i <= n ; i ++)
    {
        scanf("%lld",&w[i]);
        scanf("%s",s);
        if(s[0] == 'A') 
        {
            int N;
            scanf("%d",&N);
            while(N --)
            {
                int x,y;scanf("%d %d",&x,&y);
                add(i , x , y);
                fa[x] = i;
            }
        }
        else if(s[0] == 'B') scanf("%lld %lld",&cost[i],&temp[i]);
    }
    for(int i = 1 ; i <= n ; i ++) if(!fa[i]) add(0 , i , 1);
    dfs(0);
    DP(0);
    printf("%lld\n",dp[0][0][m]);
    return 0;
}