题解 P1174 【打砖块】

· · 题解

查看原题请戳这里

导入

我们先来看一个看似正确的分组背包的方法:

我们将每一列拆分为n个物品,第m行第k个物品的价值是\sum_{i=k}^na[i][m],其代价为\sum_{i=k}^n[pd[i][m]=1]

简单的说,第m行第k个物品的价值是第mk-n个砖块价值的总和,代价为打完这些砖块需要的子弹数。

然后,我们就以此来跑一边分组背包,于是就愉快的暴0了

为什么呢?

因为如果当前一个标记为Y的砖块在最下方,但是我们手中并没有子弹了,那么虽然这个砖块在打前和打后我们拥有的子弹数不变,即代价为0,但我们却已经没有办法去打这个砖块了。

正解

预处理

我们可以发现,如果我们当前有Y在最下方,而我们最新打的一个砖块的标记为N,那么我们完全可以先不打N,而是先打Y,然后再用新获得的子弹去打那个N

由于当某个标记为Y的砖块在最下方时,直接去打掉这个砖块肯定是最优的,所以我们可以贪心地把所有的Y都压在一起。更确切的,我们是把这些标记为Y的砖块压到了这些砖块下方的那个砖块。根据引入中提到的那个问题,由于我们打完N以后可能恰好用完了所有的子弹,所以我们用v[i][j][0]表示第i列用j发子弹且最后一发子弹打到了N上能获得的价值,用v[i][j][1]表示第i列用j发子弹且最后一发子弹打到了Y上时获得的价值。

状态设计

我们用f[i][j][0]表示前i行用j发子弹且最后一发子弹打到了标记为N的砖块能获得的最大价值,f[i][j][1]表示前i行用j发子弹且最后一发子弹打到了标记为Y的砖块能获得的最大价值。

状态转移

先贴一波代码:

for(int i = 1; i <= m; i++)
        for(int j = 0; j <= k; j++)
            for(int l = 0; l <= min(n,j); l++)
            {
                f[i][j][1] = max(f[i][j][1],f[i - 1][j - l][1] + v[i][l][1]);
                if(l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][1] + v[i][l][0]);
                if(j > l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][0] + v[i][l][1]);
            }

其中i是枚举到了前i列,j是前i列共用了j发子弹,l是第j列用了l发子弹。

f[i][j][1] = max(f[i][j][1],f[i - 1][j - l][1] + v[i][l][1]);

这个转移是说我从1j-1列借一发子弹(从最后一发子弹达到标记为Y的砖块进行转移,这样才能借到剩余的子弹),先用原本分配给这一列的l枚子弹打完所以能打的N,然后再用借来的子弹把所以压缩到这个N上的Y打掉(特殊的,如果这个N后面没有Y,那我就不打,这样无论如何最终我都会剩余一颗子弹没有用)。

if(l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][1] + v[i][l][0]);

这个转移是说如果我分配给了第j列了子弹(若l=0,则我并没有消耗子弹去打第j列的砖块,那么这个转移没有意义),我如果从1i-1列借子弹能获得的最大价值。

if(j > l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][0] + v[i][l][1]);

这个转移是说如果我让第1i-1列消耗了一定量的子弹,且不从其中某列借子弹,第1i-1列能够获得的最大价值。

注:在前两段中所说的消耗子弹是指打完某些砖块后总子弹数变少,只打标记为Y的砖块不算消耗子弹。

代码

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
#define INF 0x7fffffff
#define re register

using namespace std;

int read()
{
    register int x = 0,f = 1;register char ch;
    ch = getchar();
    while(ch > '9' || ch < '0'){if(ch == '-') f = -f;ch = getchar();}
    while(ch <= '9' && ch >= '0'){x = x * 10 + ch - 48;ch = getchar();}
    return x * f;
}

int n,m,k,cnt,a[205][205],b[205][205],v[205][205][2],f[205][205][2];

char c;

int main()
{
    n = read(); m = read();k = read();
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= m; j++)
        {
            cin >> a[i][j] >> c;
            if(c == 'Y') b[i][j] = 1;
        }
    for(int i = 1; i <= m; i++)
    {
        cnt = 0;
        for(int j = n; j >= 1; j--)
        {
            if(b[j][i]) v[i][cnt][1] += a[j][i];
            else cnt++,v[i][cnt][1] = v[i][cnt - 1][1] + a[j][i], v[i][cnt][0] = v[i][cnt - 1][1] + a[j][i];
        }
    }
    for(int i = 1; i <= m; i++)
        for(int j = 0; j <= k; j++)
            for(int l = 0; l <= min(n,j); l++)
            {
                f[i][j][1] = max(f[i][j][1],f[i - 1][j - l][1] + v[i][l][1]);
                if(l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][1] + v[i][l][0]);
                if(j > l) f[i][j][0] = max(f[i][j][0],f[i - 1][j - l][0] + v[i][l][1]);
            }
    printf("%d\n",f[m][k][0]);
    return 0;
}