题解:P10817 [EC Final 2020] Rectangle Flip 2

· · 题解

基本思路

首先我们应该都做过一种类型题,给定一个一维数组,求一段区间内的最小值;同理我们不难想到另一种题目,求一个二维数组的矩形区域内最小值,洛谷有但是我忘了是哪题

因此这道题我们可以想到,一个矩形被破坏的时间仅取决于最先碎掉的单元格,因此我们对于存在的所有矩形可以做一个统计,统计每个矩形内最先被破坏的单元格所用的时间,而时间 t 的取值范围又是 [1,25000],很方便桶排计数。这样每次查询就是 O(1) 的时间复杂度,我们只需要解决预处理的问题就行。

预处理

但是问题没有完全解决,毕竟总矩形的数量是无法在时限内统计的,因此我们需要想一种区间方法,来一次性批量统计。

这样我们不妨固定矩形上边界为 a 下边界为 b,在这个区间内,我们仅考虑以 ab 为上下边界的矩形,然后我们可以任意的选取左右的边界,这样的话可以选出 \frac{m\cdot(m+1)}{2} 种可能。之后我们如何实现批量处理呢,其实这些所有的矩形里,最小值只有 m,也就是从 1m 每个在区间 (a,b) 的最小值。因此对于这 m 个最小值,我们可以用单调栈维护两个数组,一个是 pre[j] 表示左侧第一个比 j 这个位置最小值更小的值的位置,nt[i] 就表示右侧的。

然后对于每个最小值,所对应的矩形数量就是 (j-pre[j])\cdot(nt[j]-j)

代码

#include <iostream>
#include <stack>
using namespace std;
typedef pair<int, int>PII;
const int N = 501, M = 250001;

int lg[N];
int bt[N][N];
int minv[N][N][10];
long long cnt[M];
int mv[N], pre[N], nxt[N];

int query(int col, int l, int r) {
    int k = lg[r - l + 1];
    return min(minv[l][col][k], minv[r - (1 << k) + 1][col][k]);
}

int main(void) {

    long long n = 0, m = 0;
    cin >> n >> m;
    lg[0] = -1;
    for (int i = 1; i <= n; i++) lg[i] = lg[i >> 1] + 1;

    for (int k = 1; k <= n * m; k++) {
        int x, y;
        cin >> x >> y;
        bt[x][y] = k;
        minv[x][y][0] = k;
    }

    for (int len = 1; len <= lg[n]; len++) {
        for (int i = 1; i + (1 << len) - 1 <= n; i++) {
            for (int j = 1; j <= m; j++) {
                minv[i][j][len] = min(minv[i][j][len - 1],
                    minv[i + (1 << (len - 1))][j][len - 1]);
            }
        }
    }
    for (int top = 1; top <= n; top++) {
        for (int bottom = top; bottom <= n; bottom++) {

            for (int col = 1; col <= m; col++) mv[col] = query(col, top, bottom);

            stack<PII>stk;
            for (int i = 1; i <= m; i++) {
                while (!stk.empty()) {
                    auto t = stk.top();
                    if (t.first > mv[i]) {
                        nxt[t.second] = i;
                        stk.pop();
                    }
                    else {
                        break;
                    }
                }

                if (stk.empty()) {
                    pre[i] = 0;
                }
                else {
                    pre[i] = stk.top().second;
                }
                stk.push(make_pair(mv[i], i));
            }
            while (!stk.empty()) {
                nxt[stk.top().second] = m + 1;
                stk.pop();
            }

            for (int i = 1; i <= m; i++) {
                cnt[mv[i]] += ((i - pre[i]) * (nxt[i] - i));
            }
        }
    }
    long long ans = (n + 1) * n * (m + 1) * m / 4;
    for (int i = 1; i <= n * m; i++) {
        ans -= cnt[i];
        cout << ans << endl;
    }
    return 0;
}