题解:P14945 不想玩原神
Register_int · · 题解
你这个有点牛了。
考虑对行分治,列用线段树维护。线段树上每个节点直接维护 bitset。分治的时候我们预处理出中点前的后缀信息和中点后的前缀信息合并即可。
假设我们加入了一行点,最暴力的做法就是把整个线段树重构,但这样复杂度是
考虑菜在哪了。我们对每个叶子节点都维护了长度为
于是我们以
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef bitset<2000> msg;
const int MAXN = 2e3 + 10;
const int MAXM = 5e5 + 10;
const int B = 16;
int n, m, a[MAXN][MAXN]; msg t[MAXN << 2];
void clear(int l = 1, int r = n, int p = 1) {
t[p].reset(); if (l == r) return ; int mid = l + r >> 1;
clear(l, mid, p << 1), clear(mid + 1, r, p << 1 | 1);
}
void insert(int k, int l = 1, int r = n, int p = 1) {
if (r - l + 1 <= B) { for (int i = l; i <= r; i++) t[p].set(a[k][i]); }
if (l == r) return ; int mid = l + r >> 1;
insert(k, l, mid, p << 1), insert(k, mid + 1, r, p << 1 | 1);
if (r - l + 1 > B) t[p] = t[p << 1] | t[p << 1 | 1];
}
msg ask(int ql, int qr, int l = 1, int r = n, int p = 1) {
if (ql <= l && r <= qr) return t[p];
int mid = l + r >> 1;
if (qr <= mid) return ask(ql, qr, l, mid, p << 1);
if (ql > mid) return ask(ql, qr, mid + 1, r, p << 1 | 1);
return ask(ql, qr, l, mid, p << 1) | ask(ql, qr, mid + 1, r, p << 1 | 1);
}
struct node {
int al, ar, bl, br, id;
node(int al = 0, int ar = 0, int bl = 0, int br = 0, int id = 0) :
al(al), ar(ar), bl(bl), br(br), id(id) {}
};
struct query {
int l, r, id;
query(int l = 0, int r = 0, int id = 0) : l(l), r(r), id(id) {}
}; vector<query> g[MAXN]; msg ans[MAXM];
void solve(int l, int r, const vector<node> &q) {
vector<node> vl, vr; int mid = l + r >> 1;
for (node _ : q) {
if (_.ar < mid) { vl.emplace_back(_); continue; }
if (_.al > mid) { vr.emplace_back(_); continue; }
if (_.ar > mid) g[_.ar].emplace_back(_.bl, _.br, _.id);
if (_.al <= mid) g[_.al].emplace_back(_.bl, _.br, _.id);
}
for (int i = mid; i >= l; i--) {
insert(i);
for (query _ : g[i]) ans[_.id] |= ask(_.l, _.r);
}
clear();
for (int i = mid + 1; i <= r; i++) {
insert(i);
for (query _ : g[i]) ans[_.id] |= ask(_.l, _.r);
}
clear();
for (int i = l; i <= r; i++) g[i].clear();
if (!vl.empty()) solve(l, mid, vl);
if (!vr.empty()) solve(mid + 1, r, vr);
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) scanf("%d", &a[i][j]), a[i][j]--;
}
vector<node> q; scanf("%d", &m);
for (int i = 1, al, ar, bl, br; i <= m; i++) {
scanf("%d%d%d%d", &al, &ar, &bl, &br);
q.emplace_back(al, ar, bl, br, i);
}
solve(1, n, q);
for (int i = 1; i <= m; i++) printf("%d\n", ans[i].count());
}