题解 P7318 【[JSOI2010] 挖宝藏】

· · 题解

容易发现挖掉某一个宝藏需要挖掉其上方整个倒三角内的所有方块,我们假设这个倒三角在 y=-1 上的左右端点为 l,r,那么这个宝藏在 x 轴上的映射为区间 [l,r]。进一步我们可以发现,宝藏的位置和 x 轴上的区间是一一对应的。

对于一个宝藏,我们可以在 O(1) 求出其在 x 轴上的区间 [l,r] 和单独挖出这一个宝藏的代价 w。如果一个区间被另一个区间完全包含,那么显然选了大的区间就会自动加上小区间的贡献,因此我们可以预处理出每个区间与其完全包含的区间的贡献和,这一部分的时间复杂度为 O(n^2)

之后是一个比较套路的 dp,设 f_i 为前 i 个宝藏中必选 i 的最大利润,从前往后枚举 j

然后你会发现这个 dp 是有问题的,考虑一种情况,假设对于 i 区间现在枚举到 j,而 i 区间与 j 区间的交完全包含了某个区间 k,显然 ij 都完全包含 k,那么 k 的贡献就会被重复计算。因此在 dp 的时候,对于 i,j 有交的情况,我们还需要计算被 i,j 的交所完全包含的区间的贡献和,然后减去这个值。

如果每次 dp 都暴力去找这些区间的话时间复杂度是 O(n ^ 3) 的。考虑优化找这些区间的过程,我们首先对所有区间以 r 为第一关键字,l 为第二关键字排序,那么在枚举 j 的过程中右端点一定单调不降,交的部分右端点也一定单调不降,因此转移的时候直接拿一个指针指向要计算的区间,不断右移指针就可以了。时间复杂度为 O(n^2)

#include <map>
#include <ctime>
#include <stack>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

inline int read() {
    int x = 0, w = 1;char ch = getchar();
    while (ch > '9' || ch < '0') { if (ch == '-')w = -1;ch = getchar(); }
    while (ch >= '0' && ch <= '9')x = x * 10 + ch - '0', ch = getchar();
    return x * w;
}
inline void write(int x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

const int maxn = 1e3 + 5;
const int mod = 1e9 + 7;
const int inf = 1e9;

inline int min(int x, int y) { return x < y ? x : y; }
inline int max(int x, int y) { return x > y ? x : y; }

struct node {
    int l, r, v1, v2, w;
    bool operator < (const node p) const {
        return r == p.r ? l < p.l : r < p.r;
    }
} a[maxn];
int n, ans, cnt, pos, f[maxn];

inline int calc(int l, int r) {
    int delta = (l + r) & 1;
    return (r - l + 2 + delta) * (r - l + 2 - delta) / 4;
}

int main() {
    n = read();
    for (int i = 1, x, y; i <= n; i++) {
        x = read(), y = read();
        a[i].l = x + y + 1, a[i].r = x - y - 1;
        a[i].v1 = read(), a[i].w = calc(a[i].l, a[i].r);
    }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) 
            if (a[i].l <= a[j].l && a[j].r <= a[i].r)
                a[i].v2 += a[j].v1;
    sort(a + 1, a + n + 1);
    for (int i = 1, val;i <= n;i++) {
        f[i] = val = a[i].v2 - a[i].w, pos = 1, cnt = 0;
        for (int j = 1, tmp; j < i; j++) {
            if (a[j].r >= a[i].l && a[j].l < a[i].l) {
                while (pos <= i && a[pos].r <= a[j].r) {
                    if (a[pos].l >= a[i].l) cnt += a[pos].v1;
                    pos++;
                }
                tmp = calc(a[i].l, a[j].r);
                f[i] = max(f[i], f[j] + val + tmp - cnt);
            }
            else if (a[j].r < a[i].l) f[i] = max(f[i], f[j] + val);
        }
    }
    for (int i = 0;i <= n;i++) ans = max(f[i], ans);
    printf("%d\n", ans);
    return 0; 
}