P4037 [JSOI2008]魔兽地图 题解

· · 题解

有一定难度的思维题。

输入中的装备从高级向合成它需要的低级装备连边,最后会是一个森林。

对于每个高级装备,它能购买的次数和他的儿子的购买次数有关,所以我们需要在动态规划的维度中加入购买次数。

考虑设 f_{u,i,j} 表示以 u 为根的子树,至少合成 iu 装备,总共在子树中花费 j 金币能得到的最大能力值(如果不合法,即金币不够,f_{u,i,j}=-\inf)。

for(int i = 0; i <= lim[u]; i++)
{
    for(int j = 0; j <= m; j++)
    {
        f[u][i][j] = ((j < i * c[u]) || (i > lim[u]) ? -inf : i * s[u]);
    }
}
f[u][i][j] = max(-inf, f[u][i][j] + f[v][i * w][0] - i * w * s[v]);
for(int k = 1; k <= j; k++)
{
    f[u][i][j] = max(f[u][i][j], f[u][i][j - k] + f[v][i * w][k] - i * w * s[v]);
}
for(int j = 0; j <= m; j++)
{
    f[u][lim[u] + 1][j] = -inf;
    for(int i = lim[u]; i >= 0; i--)
    {
        f[u][i][j] = max(f[u][i + 1][j], f[u][i][j]);
    }
}

理论复杂度:O(n\times \lim \times m^2),计算器敲一下发现好像极限数据是 2.04e10,luogu 交一发 T 到只有40pts,需要优化。

主要是对 dp 转移范围的优化:

朴素的写法是两层循环老实地从 0-m,但是考虑前一层枚举的 i 的意义:u 点买 i 件,那么代价至少是 i \times cost[u] ,同理 v 点买 i \times w 件,代价至少是 i \times w \times cost[v]

然后时间复杂度:O(\texttt{玄学})

Code:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define maxrd 100005
inline char GET_CHAR()
{
    static char buf[maxrd], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, maxrd, stdin), p1 == p2) ? EOF : *p1++;
}
template <class T> inline void read(T &x)
{
    x = 0;
    int f = 0;
    char ch = GET_CHAR();
    while(ch < '0' || ch > '9')
    {
        f |= ch == '-';
        ch = GET_CHAR();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - 48);
        ch = GET_CHAR();
    }
    x = f ? -x : x;
    return;
}
inline void readch(char &x)
{
    x = GET_CHAR();
    while(x < 'A' && x > 'Z')
    {
        x = GET_CHAR();
    }
    return;
}
#define min(x, y) ((x) < (y) ? x : y)
#define max(x, y) ((x) > (y) ? x : y)
#define inf 0x3f3f3f3f
#define N 55
#define SIZE 100
#define M 2005
int first[N], Next[N], to[N], w[N], tot;
inline void add(int x, int y, int z)
{
    Next[++tot] = first[x];
    first[x] = tot;
    to[tot] = y;
    w[tot] = z;
    return;
}
int n, m;
int s[N], c[N], lim[N], cost[N];
int f[N][SIZE + 5][M], tmp[SIZE + 5][M];
void update(int u, int v, int w)
{
    if(u == n + 1)
    {
        for(int i = m; i >= 0; i--)
        {
            for(int j = 0; j <= i; j++)
            {
                f[u][0][i] = max(f[u][0][i], f[u][0][i - j] + f[v][0][j]);
            }
        }
        return;
    }
    memset(tmp, -0x3f, sizeof(tmp));
    for(int i = 0; i <= lim[u]; i++)
    {
        for(int j = m; j >= cost[u] * i; j--)
        {
            tmp[i][j] = max(tmp[i][j], f[u][i][j] + f[v][i * w][0] - i * w * s[v]);
            for(int k = cost[v] * i * w; k <= j; k++)
            {
                tmp[i][j] = max(tmp[i][j], f[u][i][j - k] + f[v][i * w][k] - i * w * s[v]);
            }
        }
    }
    memcpy(f[u], tmp, sizeof(f[u]));
    return;
}
void dfs(int u)
{
    for(int i = 0; i <= lim[u]; i++)
    {
        for(int j = 0; j <= m; j++)
        {
            f[u][i][j] = ((j < i * c[u]) || (i > lim[u]) ? -inf : i * s[u]);
        }
    }
    for(int i = first[u]; i; i = Next[i])
    {
        int v = to[i];
        dfs(v);
        lim[u] = min(lim[u], lim[v] / w[i]);
        cost[u] += w[i] * cost[v];
        update(u, v, w[i]);
    }
    for(int j = 0; j <= m; j++)
    {
        f[u][lim[u] + 1][j] = -inf;
        for(int i = lim[u]; i >= 0; i--)
        {
            f[u][i][j] = max(f[u][i + 1][j], f[u][i][j]);
        }
    }
    return;
}
int rd[N], root;
signed main()
{
    read(n), read(m);
    char opt;
    int cnt, ch, len;
    for(int i = 1; i <= n; i++)
    {
        read(s[i]), readch(opt);
        if(opt == 'A')
        {
            lim[i] = SIZE;
            read(cnt);
            while(cnt--)
            {
                read(ch), read(len);
                add(i, ch, len);
                rd[ch]++;
            }
        }
        else
        {
            read(c[i]), read(lim[i]);
            cost[i] = c[i];
        }
    }
    for(int i = 1; i <= n; i++)
    {
        if(!c[i] && !rd[i])
        {
            add(n + 1, i, 1);
        }
    }
    dfs(n + 1);
    printf("%d", f[n + 1][0][m]);
    return 0;
}