题解:P14945 不想玩原神

· · 题解

你这个有点牛了。

考虑对行分治,列用线段树维护。线段树上每个节点直接维护 bitset。分治的时候我们预处理出中点前的后缀信息和中点后的前缀信息合并即可。

假设我们加入了一行点,最暴力的做法就是把整个线段树重构,但这样复杂度是 \mathcal{O}\left(\frac{n^2}{w}\right)。考虑到分治结构要重构 \mathcal{O}(n\log n) 次,乘起来已经倒闭了。

考虑菜在哪了。我们对每个叶子节点都维护了长度为 n 的 bitset,这显然会导致很多浪费。在靠近叶子的地方,明明每次只有大概 \mathcal{O}(1) 的修改量,却还是要 \mathcal{O}\left(\frac{n}{w}\right) 更新。因此考虑设定一个阈值,阈值之上直接 pushup,阈值之下暴力更新。假设这个阈值为 B,那么插入一行的复杂度为 \mathcal{O}\left(\frac{n^2}{Bw}+n\log B\right),大概取 B=\mathcal{O}\left(\frac{n}{w\log(n/w)}\right) 得到最优复杂度 \mathcal{O}\left(n\log \frac{n}{w}\right)。实际上 B16 很优秀,但是 B=1B=n 都能过,哈哈。

于是我们以 \mathcal{O}\left(n^2\log n\log\frac{n}{w}+\frac{nq\log n}{w}\right) 的复杂度解决了这题。

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef bitset<2000> msg; 

const int MAXN = 2e3 + 10;
const int MAXM = 5e5 + 10;
const int B = 16;

int n, m, a[MAXN][MAXN]; msg t[MAXN << 2];

void clear(int l = 1, int r = n, int p = 1) {
    t[p].reset(); if (l == r) return ; int mid = l + r >> 1;
    clear(l, mid, p << 1), clear(mid + 1, r, p << 1 | 1);
}

void insert(int k, int l = 1, int r = n, int p = 1) {
    if (r - l + 1 <= B) { for (int i = l; i <= r; i++) t[p].set(a[k][i]); }
    if (l == r) return ; int mid = l + r >> 1;
    insert(k, l, mid, p << 1), insert(k, mid + 1, r, p << 1 | 1);
    if (r - l + 1 > B) t[p] = t[p << 1] | t[p << 1 | 1];
}

msg ask(int ql, int qr, int l = 1, int r = n, int p = 1) {
    if (ql <= l && r <= qr) return t[p];
    int mid = l + r >> 1;
    if (qr <= mid) return ask(ql, qr, l, mid, p << 1);
    if (ql > mid) return ask(ql, qr, mid + 1, r, p << 1 | 1);
    return ask(ql, qr, l, mid, p << 1) | ask(ql, qr, mid + 1, r, p << 1 | 1);
}

struct node {
    int al, ar, bl, br, id;
    node(int al = 0, int ar = 0, int bl = 0, int br = 0, int id = 0) : 
        al(al), ar(ar), bl(bl), br(br), id(id) {}
};

struct query {
    int l, r, id;
    query(int l = 0, int r = 0, int id = 0) : l(l), r(r), id(id) {}
}; vector<query> g[MAXN]; msg ans[MAXM];

void solve(int l, int r, const vector<node> &q) {
    vector<node> vl, vr; int mid = l + r >> 1;
    for (node _ : q) {
        if (_.ar < mid) { vl.emplace_back(_); continue; }
        if (_.al > mid) { vr.emplace_back(_); continue; }
        if (_.ar > mid) g[_.ar].emplace_back(_.bl, _.br, _.id);
        if (_.al <= mid) g[_.al].emplace_back(_.bl, _.br, _.id);
    }
    for (int i = mid; i >= l; i--) {
        insert(i);
        for (query _ : g[i]) ans[_.id] |= ask(_.l, _.r);
    }
    clear();
    for (int i = mid + 1; i <= r; i++) {
        insert(i);
        for (query _ : g[i]) ans[_.id] |= ask(_.l, _.r);
    }
    clear();
    for (int i = l; i <= r; i++) g[i].clear();
    if (!vl.empty()) solve(l, mid, vl);
    if (!vr.empty()) solve(mid + 1, r, vr);
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) scanf("%d", &a[i][j]), a[i][j]--;
    }
    vector<node> q; scanf("%d", &m);
    for (int i = 1, al, ar, bl, br; i <= m; i++) {
        scanf("%d%d%d%d", &al, &ar, &bl, &br);
        q.emplace_back(al, ar, bl, br, i);
    }
    solve(1, n, q);
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i].count());
}