P3437 Tetris 3D

· · 题解

简要题意

给定一个矩阵,初始时每个位置上的数都是 0,每次查询一个子矩阵的最大值,并将这个子矩阵的所有数修改为最大值加 w,输出最后整个矩阵的最大值。

思路

使用线段树套线段树解决。外层线段树维护行,每个结点维护一棵线段树,一个区间为 [l, r] 的结点维护的线段树为矩阵第 l 行到第 r 行的最大值。因为需要区间修改,还需要维护懒标记。值得注意的是,外层线段树的标记是一棵线段树。

由于外层线段树更新结点信息和下传标记都相当于两棵线段树合并,所以使用传统的方法会 T 飞,因此我们需要使用标记永久化。

标记永久化

我们可以利用一维线段树来帮助理解标记永久化。假如我们需要实现维护序列的区间加、区间求和,我们在更新的过程中便修改经过结点的权值,然后对被完全覆盖的区间打标记;在查询时,我们记录经过结点的标记之和,对于被完全覆盖的区间,这个区间实际的和为记录的和加上区间长度乘标记。对应的代码是这样的(代码中 lr 表示当前结点维护的区间为 [l, r]LR 表示查询/修改的区间为 [L, R]):

void update(int pos, int l, int r, int L, int R, int val) {
    sum[pos] += (R - L + 1) * val;
    if (L == l && R == r) {
        tag[pos] += val;
        return;
    }
    int mid = (l + r) / 2;
    if (L <= mid)
        update(ls[pos], l, mid, L, min(R, mid), val);
    if (R > mid)
        update(rs[pos], mid + 1, r, max(L, mid + 1), R, val);
}
int query(int pos, int l, int r, int L, int R, int val) {
    if (L <= l && R >= r)
        return sum[pos] + (r - l + 1) * val;
    val += tag[pos];
    int mid = (l + r) / 2;
    if (R <= mid)
        return query(ls[pos], l, mid, L, R, val);
    if (L > MID)
        return query(rs[pos], mid + 1, r, L, R, val);
    return query(ls[pos], l, mid, L, R, val) + query(rs[pos], mid + 1, r, L, R, val);
}

对于线段树套线段树的区间永久化,和一维线段树类似,只不过外层维护的最大值和懒标记都是一个一维线段树。

容易看出 经过亿些思考,可以发现有一些信息是不能用标记永久化维护的。因为传统的写法中一个结点维护的信息是通过子区间的信息合并得来,而标记永久化是在更新的时候计算对有关结点信息的改变。如果是随机的区间修改与区间求最大值是不能用标记永久化维护的,因为当一个位置的值缩小时,我们无法在更新的过程中得知经过结点信息应该如何变化。

这道题的特殊之处在于区间修改的值比修改前的最大值要大,因此对经过结点的影响是可以计算的,具体来说就是在更新过程中对结点维护的信息取最大。

代码

#include <iostream>
#define LP pos << 1
#define RP LP | 1
#define MID (l + r >> 1)
#define LS l, MID
#define RS MID + 1, r
#define THIS pos, l, r
using namespace std;
const int MAXN = 2e4 + 5;
int n, m, q, tot, ls[MAXN * 100], rs[MAXN * 100], num[MAXN * 100], tag[MAXN * 100];
class SegmentTree1D {
private:
    int root;
    void update(int &pos, int l, int r, int L, int R, int val) {
        if (!pos)
            pos = ++tot;
        num[pos] = max(num[pos], val);
        if (L <= l && R >= r)
            return tag[pos] = max(tag[pos], val), void(); // 这里
        if (L <= MID)
            update(ls[pos], LS, L, R, val);
        if (R > MID)
            update(rs[pos], RS, L, R, val);
    }
    int query(int pos, int l, int r, int L, int R, int val) {
        if (L <= l && R >= r)
            return max(num[pos], val);
        val = max(val, tag[pos]);
        if (R <= MID)
            return query(ls[pos], LS, L, R, val);
        if (L > MID)
            return query(rs[pos], RS, L, R, val);
        return max(query(ls[pos], LS, L, R, val), query(rs[pos], RS, L, R, val));
    }

public:
    void update(int l, int r, int x) { update(root, 1, m, l, r, x); }
    int query(int l, int r) { return query(root, 1, m, l, r, 0); }
};
class SegmentTree2D {
private:
    SegmentTree1D num[MAXN * 4], tag[MAXN * 4];
    void update(int pos, int l, int r, int L, int R, int a, int b, int val) {
        num[pos].update(a, b, val);
        if (L <= l && R >= r)
            return tag[pos].update(a, b, val);
        if (L <= MID)
            update(LP, LS, L, R, a, b, val);
        if (R > MID)
            update(RP, RS, L, R, a, b, val);
    }
    int query(int pos, int l, int r, int L, int R, int a, int b, int val) {
        if (L <= l && R >= r)
            return max(num[pos].query(a, b), val);
        val = max(val, tag[pos].query(a, b));
        if (R <= MID)
            return query(LP, LS, L, R, a, b, val);
        if (L > MID)
            return query(RP, RS, L, R, a, b, val);
        return max(query(LP, LS, L, R, a, b, val), query(RP, RS, L, R, a, b, val));
    }

public:
    void update(int l, int r, int a, int b, int x) { update(1, 1, n, l, r, a, b, x); }
    int query(int l, int r, int a, int b) { return query(1, 1, n, l, r, a, b, 0); }
} t;
int main() {
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m >> q;
    for (int i = 1, d, s, w, x, y; i <= q; ++i)
        cin >> d >> s >> w >> x >> y, t.update(x + 1, x + d, y + 1, y + s, t.query(x + 1, x + d, y + 1, y + s) + w);
    cout << t.query(1, n, 1, m) << "\n";
}

关于内层线段树的标记为什么是取最大值而不是直接赋值

代码中的注释处是内层线段树的标记,这里必须写成取最大值。对于外层线段树,在更新过程中对经过结点维护的线段树的更新的含义是取最大值,而不是区间修改。同时,因为这道题对于区间修改的特殊性,区间修改相当于区间中的所有数对一个数取最大值。因此内层线段树的标记必须是取最大值。(我们可以类比一下 代码中的注释处上数第二行写成 num[pos] = val; 对于一维线段树的影响,相当于注释处写成赋值对二维线段树的影响)