Segment Tree Nested with BIT Nested with Value-Range Blocking Nested with Block Linked List
stripe_python · · Algorithms·Theory
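A tongue-in-cheek approach to P4278; it probably fits the "Leisure & Entertainment" category better.
Using …
Build an index BIT over the block linked list; binary searching on this BIT locates the block (and the offset inside it) for any index, saving …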
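Here is a minimal sketch of such an index BIT. It assumes, as both implementations below do, that BIT position $b+1$ stores the size of list block $b$. Descending over powers of two finds the largest $p$ whose prefix of block sizes is at most $k$. That $p$ is the block holding the 0-based index $k$, and whatever remains of $k$ is the offset inside it. `BlockIndex`, `grow`, and `locate` are hypothetical names. Unlike the post's `Index`, this version also stores the last block.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical index structure: a Fenwick tree over list-block sizes.
// Position b + 1 holds the size of block b (1-based Fenwick positions).
struct BlockIndex {
    std::vector<int> bit;

    explicit BlockIndex(const std::vector<int>& sizes) : bit(sizes.size() + 1, 0) {
        int n = sizes.size();
        for (int i = 1; i <= n; i++) {  // O(n) build: push each node into its parent
            bit[i] += sizes[i - 1];
            int j = i + (i & -i);
            if (j <= n) bit[j] += bit[i];
        }
    }

    // Block b received one more element.
    void grow(int b) {
        for (int i = b + 1; i < (int)bit.size(); i += i & -i) bit[i]++;
    }

    // Map a 0-based global index k (k < total size) to {block, offset}:
    // find the largest p whose prefix of block sizes is <= k.
    std::pair<int, int> locate(int k) const {
        int n = (int)bit.size() - 1, p = 0, step = 1;
        while (step * 2 <= n) step *= 2;
        for (; step > 0; step >>= 1) {
            if (p + step <= n && bit[p + step] <= k) {
                k -= bit[p + step];
                p += step;
            }
        }
        return {p, k};
    }
};
```

For example, with block sizes {3, 0, 2, 4}, index 3 maps to offset 0 of block 2, because block 1 is empty.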
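Maintain the blocks of the block linked list with a segment tree. One idea is to give every segment-tree node an array …
However, each modification must update a whole root-to-leaf chain of the segment tree, and during pushup, for each …
Turn all …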
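The `build()` of the first implementation sets each internal node's array to the element-wise sum of its children's Fenwick arrays. This works because a Fenwick array is a linear function of the counts it summarizes. The sketch below uses the hypothetical names `Fenwick` and `merged` to check this on a tiny example. Such a merge costs time linear in the value range per node. That cost is why point updates in that code skip pushup and instead add directly into every node on the root-to-leaf path.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal Fenwick tree counting occurrences of values 1..n.
struct Fenwick {
    std::vector<int> a;
    explicit Fenwick(int n) : a(n + 1, 0) {}
    void add(int x, int c) {
        for (; x < (int)a.size(); x += x & -x) a[x] += c;
    }
    int prefix(int x) const {  // how many stored values lie in [1, x]
        int s = 0;
        for (; x > 0; x -= x & -x) s += a[x];
        return s;
    }
};

// Parent = element-wise sum of the children's raw Fenwick arrays,
// mirroring `rk[rt][i] = rk[ls][i] + rk[rs][i]` in the first implementation.
Fenwick merged(const Fenwick& l, const Fenwick& r) {
    Fenwick p((int)l.a.size() - 1);
    for (std::size_t i = 0; i < p.a.size(); i++) p.a[i] = l.a[i] + r.a[i];
    return p;
}
```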
::::info[Code]
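#include <bits/stdc++.h>
using namespace std;
// The post omits the author's competitive-programming template; the helpers
// below are minimal cin/cout-based stand-ins, not the original FastIO.
void _main();
int main() { ios::sync_with_stdio(false), cin.tie(nullptr); _main(); }
inline int lowbit(int x) { return x & -x; }
inline int lg(int x) { return 31 - __builtin_clz(x); }  // floor(log2(x)), x > 0
template <class... T> void read(T&... xs) { (cin >> ... >> xs); }
void read(int* first, int* last) { while (first != last) cin >> *first++; }
void readchar(char& c) { cin >> c; }
void writeln(int x) { cout << x << '\n'; }
namespace FastIO { void flush() { cout.flush(); } }
constexpr int N = 3.5e4 + 5, M = 7e4 + 5;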
constexpr int B = 800;
vector<vector<int>> lists;
struct Index { // index BIT: position i + 1 holds the size of block i
vector<int> bit;
void clear() {bit.clear();}
void build() {
int n = lists.size();
bit.resize(n, 0);
for (int i = 1; i < n; i++) {
bit[i] += lists[i - 1].size();
if (i + lowbit(i) < n) bit[i + lowbit(i)] += bit[i];
}
}
void add(int x) {
for (; x < (int) bit.size(); x += lowbit(x)) bit[x]++;
}
std::pair<int, int> find(int k) {
if (k < lists[0].size()) return {0, k};
if (bit.empty()) build();
int p = 0, n = bit.size();
for (int i = lg(n); i >= 0; i--) {
int x = p + (1 << i);
if (x < n && k >= bit[x]) k -= bit[x], p = x;
} return {p, k};
}
} idx;
// counts of values <= x over ranges of blocks
struct SGT {
#define ls (rt << 1)
#define rs (rt << 1 | 1)
int rk[4 * (N + M) / B][M];
void add(int* bit, int x, int c) {
for (; x < M; x += lowbit(x)) bit[x] += c;
}
int ask(const int* bit, int x) const {
int res = 0; // plain prefix count, so results from several nodes can be added
for (; x >= 1; x -= lowbit(x)) res += bit[x];
return res;
}
void build(int l = 0, int r = lists.size() - 1, int rt = 1) {
if (l == r) {
memset(rk[rt], 0, sizeof(rk[rt]));
for (int i : lists[l]) add(rk[rt], i + 1, 1);
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls), build(mid + 1, r, rs);
for (int i = 1; i < M; i++) rk[rt][i] = rk[ls][i] + rk[rs][i];
}
// block x gains (type = 1) or loses (type = -1) one copy of value c
void update(int x, int c, int type, int l = 0, int r = lists.size() - 1, int rt = 1) {
add(rk[rt], c + 1, type);
if (l == r) return;
int mid = (l + r) >> 1;
if (x <= mid) update(x, c, type, l, mid, ls);
else update(x, c, type, mid + 1, r, rs);
}
// number of values <= c in blocks [tl, tr]
int ask(int tl, int tr, int c, int l = 0, int r = lists.size() - 1, int rt = 1) const {
if (tl <= l && r <= tr) return ask(rk[rt], c + 1);
int mid = (l + r) >> 1, res = 0;
if (tl <= mid) res += ask(tl, tr, c, l, mid, ls);
if (tr > mid) res += ask(tl, tr, c, mid + 1, r, rs);
return res;
}
} sgt;
int tot;
int kth(int l, int r, int k) {
auto [bl, pl] = idx.find(l);
auto [br, pr] = idx.find(r);
if (bl == br) {
vector<int> vec(lists[bl].begin() + pl, lists[bl].begin() + pr + 1);
nth_element(vec.begin(), vec.begin() + k - 1, vec.end());
return vec[k - 1];
}
auto ask = [&](int x) -> int {
int cnt = 0;
for (int i = pl; i < (int) lists[bl].size(); i++) cnt += (lists[bl][i] <= x);
for (int i = 0; i <= pr; i++) cnt += (lists[br][i] <= x);
if (bl + 1 <= br - 1) cnt += sgt.ask(bl + 1, br - 1, x);
return cnt;
};
int lo = 0, hi = M - 1, res = 0;
while (lo <= hi) {
int mid = (lo + hi) >> 1;
if (ask(mid) >= k) hi = mid - 1, res = mid;
else lo = mid + 1;
} return res;
}
void modify(int x, int c) {
auto [b, p] = idx.find(x);
sgt.update(b, lists[b][p], -1), sgt.update(b, c, 1);
lists[b][p] = c;
}
void pushup(int b, int v) {
tot++;
if ((int) lists[b].size() < 2 * B) {
sgt.update(b, v, 1);
idx.add(b + 1); // block b is stored at BIT position b + 1
} else {
vector<int> half(lists[b].begin() + B, lists[b].end());
lists[b].erase(lists[b].begin() + B, lists[b].end());
lists.insert(lists.begin() + b + 1, half);
sgt.build(), idx.clear();
}
}
void insert(int x, int c) {
if (x == tot) {
lists.back().emplace_back(c);
pushup(lists.size() - 1, c);
} else {
auto [b, p] = idx.find(x);
lists[b].insert(lists[b].begin() + p, c);
pushup(b, c);
}
}
int n, a[N], q;
char opt;
void build() {
tot = n;
for (int i = 1; i <= n; i += B) lists.emplace_back(a + i, a + min(n + 1, i + B));
sgt.build();
}
void _main() {
read(n), read(a + 1, a + n + 1), read(q);
build();
for (int l, r, k, last = 0; q--; ) {
readchar(opt), read(l, r), l ^= last, r ^= last;
if (opt == 'Q') read(k), k ^= last, writeln(last = kth(l - 1, r - 1, k));
else if (opt == 'M') modify(l - 1, r);
else if (opt == 'I') insert(l - 1, r);
} FastIO::flush();
}
::::
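A BIT nested with a persistent value-range segment tree can bring this down to …
Consider value-range blocking instead. For …
We still use a segment tree of BITs to maintain the block linked list, but here the BITs maintain prefix sums over the value blocks. So this approach can fairly be called: segment tree nested with BIT nested with value-range blocking nested with block linked list.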
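Here is a single-structure sketch of the value-blocking layer. It is loosely modeled on the post's `Block`, which keeps one sorted vector per value block, but attaches a Fenwick tree over value-block counts directly to it. Counting values $\le v$ then combines the whole value blocks before $v$'s block with one `upper_bound` inside that block. In the actual solution, the segment tree sums the whole-block part over many list blocks; here it covers a single multiset. `ValueBlocked` and the bucket width $2^2$ are illustrative assumptions (the post uses $2^6$).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical sketch: values in [0, maxV] are split into value blocks of width 2^SHIFT.
// A Fenwick tree over block counts answers "how many values lie in blocks 0..b";
// one sorted block answers "how many values <= v inside block v >> SHIFT".
struct ValueBlocked {
    static constexpr int SHIFT = 2;        // block width 4 (the post uses 1 << 6)
    int nb;                                // number of value blocks
    std::vector<int> fw;                   // Fenwick over block counts, 1-based
    std::vector<std::vector<int>> bucket;  // sorted values per value block
    explicit ValueBlocked(int maxV) : nb((maxV >> SHIFT) + 1), fw(nb + 1, 0), bucket(nb) {}

    // c = +1 inserts v, c = -1 erases one copy of v (v must be present).
    void add(int v, int c) {
        auto& b = bucket[v >> SHIFT];
        if (c > 0) b.insert(std::lower_bound(b.begin(), b.end(), v), v);
        else b.erase(std::lower_bound(b.begin(), b.end(), v));
        for (int x = (v >> SHIFT) + 1; x <= nb; x += x & -x) fw[x] += c;
    }
    // Number of values in value blocks 0..b (b may be -1).
    int blocksUpTo(int b) const {
        int s = 0;
        for (int x = b + 1; x > 0; x -= x & -x) s += fw[x];
        return s;
    }
    // Number of values <= v: whole blocks before v's block, plus a partial count inside it.
    int countLeq(int v) const {
        int b = v >> SHIFT;
        const auto& bk = bucket[b];
        return blocksUpTo(b - 1) + int(std::upper_bound(bk.begin(), bk.end(), v) - bk.begin());
    }
};
```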
For a modification, update the value blocks and the segment tree, at a cost of …
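This sketch shows the segment-tree layer used for modifications. It is a segment tree over list blocks $1..m$ in which every node holds a Fenwick tree over value-block counts, in the spirit of the post's `SGT`. Changing a value moves one unit from the old value block to the new one along a single root-to-leaf path, which takes $O(\log m)$ Fenwick updates. `SegOfBits` and `modifyValue` are hypothetical names, and the per-node `std::vector` simplifies the post's fixed-size arrays.

```cpp
#include <cassert>
#include <vector>

// Hypothetical "segment tree of BITs": a segment tree over list blocks 1..m,
// where every node keeps a Fenwick tree over the value-block counts of all
// elements stored in its range of list blocks.
struct SegOfBits {
    int m, nb;                         // list blocks, value blocks
    std::vector<std::vector<int>> fw;  // fw[node]: 1-based Fenwick array of size nb + 1

    SegOfBits(int m, int nb) : m(m), nb(nb), fw(4 * m, std::vector<int>(nb + 1, 0)) {}

    // Add c to value block vb of list block x. Every node on the root-to-leaf
    // path gets one Fenwick update: O(log m * log nb) in total.
    void update(int x, int vb, int c) { update(x, vb, c, 1, m, 1); }

    // Number of elements in list blocks [ql, qr] whose value block is <= vb.
    int query(int ql, int qr, int vb) const { return ql > qr ? 0 : query(ql, qr, vb, 1, m, 1); }

private:
    void update(int x, int vb, int c, int l, int r, int node) {
        for (int i = vb + 1; i <= nb; i += i & -i) fw[node][i] += c;
        if (l == r) return;
        int mid = (l + r) / 2;
        if (x <= mid) update(x, vb, c, l, mid, node * 2);
        else update(x, vb, c, mid + 1, r, node * 2 + 1);
    }
    int query(int ql, int qr, int vb, int l, int r, int node) const {
        if (qr < l || r < ql) return 0;
        if (ql <= l && r <= qr) {
            int s = 0;
            for (int i = vb + 1; i > 0; i -= i & -i) s += fw[node][i];
            return s;
        }
        int mid = (l + r) / 2;
        return query(ql, qr, vb, l, mid, node * 2) + query(ql, qr, vb, mid + 1, r, node * 2 + 1);
    }
};

// Point modification inside list block x: one element moves from value
// block oldVb to value block newVb.
void modifyValue(SegOfBits& t, int x, int oldVb, int newVb) {
    t.update(x, oldVb, -1);
    t.update(x, newVb, +1);
}
```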
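For a query, first binary-search over the value blocks to find the value block containing the target, scanning the partial blocks first and then querying the segment tree, at a cost of …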
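The two-level search can be sketched on its own. In this version, brute-force counting stands in for the segment tree and the block scans, so it is not fast; it only shows the control flow. Level one finds the smallest value block whose cumulative count reaches $k$. Level two binary-searches values inside that block. `kthTwoLevel` and its parameters are hypothetical, and `W` plays the role of the post's value-block width `B`. The sketch assumes $1 \le k \le$ `vals.size()` and values in $[0, \mathrm{maxV}]$.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical sketch of the two-level k-th search with value blocks of width W.
// countInBlocksUpTo(b): how many values fall in value blocks 0..b (segment-tree part).
// countLeq(v): how many values are <= v (block scans plus the in-block part).
// Both are brute force here; the post gets them from the segment tree and the blocks.
int kthTwoLevel(const std::vector<int>& vals, int k, int maxV, int W) {
    auto countInBlocksUpTo = [&](int b) {
        return (int)std::count_if(vals.begin(), vals.end(), [&](int x) { return x / W <= b; });
    };
    auto countLeq = [&](int v) {
        return (int)std::count_if(vals.begin(), vals.end(), [&](int x) { return x <= v; });
    };
    // Level 1: smallest value block b with countInBlocksUpTo(b) >= k.
    int lo = 0, hi = maxV / W, b = hi;
    while (lo <= hi) {
        int mid = (lo + hi) / 2;
        if (countInBlocksUpTo(mid) >= k) b = mid, hi = mid - 1;
        else lo = mid + 1;
    }
    // Level 2: smallest v inside block b with countLeq(v) >= k.
    lo = b * W, hi = std::min(maxV, b * W + W - 1);
    int res = hi;
    while (lo <= hi) {
        int mid = (lo + hi) / 2;
        if (countLeq(mid) >= k) res = mid, hi = mid - 1;
        else lo = mid + 1;
    }
    return res;
}
```

After level one fixes the answer's value block, level two needs only counts of values $\le v$: whole value blocks before `b` plus a partial count inside `b`. That is what the post's `ask(l, r, v, b)` computes.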
For this problem, treat …
Meanwhile, the space bottleneck is the segment-tree part, at …
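- If $B=O(\sqrt{n})$, the time complexity is $O(n\sqrt{n}+n\log^3 n)$, where the second term carries a constant factor of $\dfrac{1}{8}$, and the space complexity is $O(n)$. This is under $2\times 10^7$, so it passes this problem comfortably.
- If $B=O(n^{1/3})$, the time complexity is $O(n^{4/3}+n\log^3 n)$, where the second term carries a constant factor of $\dfrac{4}{27}$, and the space is $O(n^{4/3})$. This should be the theoretically optimal complexity.
- If $B=O(1)$, the time is $O(n\log^2 n)$ and the space is $O(n^2)$.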
In my implementation, …
::::info[Code]
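#include <bits/stdc++.h>
using namespace std;
// The post omits the author's competitive-programming template; the helpers
// below are minimal cin/cout-based stand-ins, not the original FastIO.
void _main();
int main() { ios::sync_with_stdio(false), cin.tie(nullptr); _main(); }
inline int lowbit(int x) { return x & -x; }
inline int lg(int x) { return 31 - __builtin_clz(x); }  // floor(log2(x)), x > 0
template <class... T> void read(T&... xs) { (cin >> ... >> xs); }
void read(int* first, int* last) { while (first != last) cin >> *first++; }
void readchar(char& c) { cin >> c; }
void writeln(int x) { cout << x << '\n'; }
namespace FastIO { void flush() { cout.flush(); } }
constexpr int N = 3.5e4 + 5, M = 7e4 + 5;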
constexpr int O = 6, B = 1 << O, NB = (M >> O) + 1, L = 789, Y = L << 1;
struct Block {
vector<int> val;
vector<int> bv[NB + 1];
void build() {
for (auto& b : bv) b.clear();
for (int v : val) bv[v >> O].emplace_back(v);
for (auto& b : bv) sort(b.begin(), b.end());
}
void del(int v) {
auto& b = bv[v >> O];
b.erase(lower_bound(b.begin(), b.end(), v));
}
void add(int v) {
auto& b = bv[v >> O];
b.insert(lower_bound(b.begin(), b.end(), v), v);
}
int ask(int b, int v) const {
return upper_bound(bv[b].begin(), bv[b].end(), v) - bv[b].begin();
}
int size() const {return val.size();}
};
vector<Block> blocks;
struct Index { // index BIT: position i + 1 holds the size of block i
vector<int> bit;
void clear() {bit.clear();}
void build() {
int n = blocks.size();
bit.resize(n, 0);
for (int i = 1; i < n; i++) {
bit[i] += blocks[i - 1].size();
if (i + lowbit(i) < n) bit[i + lowbit(i)] += bit[i];
}
}
std::pair<int, int> find(int k) {
k--;
if (k < blocks[0].size()) return {0, k};
if (bit.empty()) build();
int p = 0, n = bit.size();
for (int i = lg(n); i >= 0; i--) {
int x = p + (1 << i);
if (x < n && k >= bit[x]) k -= bit[x], p = x;
}
return {p, k};
}
} idx;
struct SGT {
#define ls (rt << 1)
#define rs (rt << 1 | 1)
struct BIT {
int a[NB + 1];
void clear() {memset(a, 0, sizeof(a));}
void add(int x, int c) {
for (; x <= NB; x += lowbit(x)) a[x] += c;
}
int ask(int x) const {
int res = 0;
for (; x >= 1; x -= lowbit(x)) res += a[x];
return res;
}
} bit[4 * (N + M) / L];
void build(int l = 1, int r = blocks.size(), int rt = 1) {
bit[rt].clear();
if (l == r) {
for (int b = 0; b <= NB; b++) bit[rt].add(b + 1, blocks[l - 1].bv[b].size());
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls), build(mid + 1, r, rs);
#pragma GCC unroll 8
for (int i = 0; i <= NB; i++) bit[rt].a[i] = bit[ls].a[i] + bit[rs].a[i];
}
void update(int x, int b, int c, int l = 1, int r = blocks.size(), int rt = 1) {
bit[rt].add(b + 1, c);
if (l == r) return;
int mid = (l + r) >> 1;
if (x <= mid) update(x, b, c, l, mid, ls);
else update(x, b, c, mid + 1, r, rs);
}
int ask(int tl, int tr, int b, int l = 1, int r = blocks.size(), int rt = 1) const {
if (tl > tr) return 0;
if (tl <= l && r <= tr) return bit[rt].ask(b + 1);
int mid = (l + r) >> 1, res = 0;
if (tl <= mid) res += ask(tl, tr, b, l, mid, ls);
if (tr > mid) res += ask(tl, tr, b, mid + 1, r, rs);
return res;
}
} sgt;
void expand(int pos) {
int len = blocks[pos].size();
if (len < Y) {
int n = idx.bit.size();
for (int i = pos + 1; i < n; i += lowbit(i)) idx.bit[i]++;
} else {
int mid = len >> 1;
Block b;
b.val.assign(blocks[pos].val.begin() + mid, blocks[pos].val.end());
blocks[pos].val.resize(mid);
blocks[pos].build(), b.build();
blocks.insert(blocks.begin() + pos + 1, b);
idx.clear(), sgt.build();
}
}
int tot;
int ask(int l, int r, int v, int b) {
auto [bl, pl] = idx.find(l);
auto [br, pr] = idx.find(r);
auto F = [&](int b, int l, int r) -> int {
int res = 0;
#pragma GCC unroll 8
for (int i = l; i <= r; i++) {
if (blocks[b].val[i] <= v) res++;
} return res;
};
if (bl == br) return F(bl, pl, pr);
int res = F(bl, pl, blocks[bl].size() - 1) + F(br, 0, pr);
if (bl + 1 <= br - 1) {
if (b != 0) res += sgt.ask(bl + 2, br, b - 1);
for (int i = bl + 1; i <= br - 1; i++) res += blocks[i].ask(b, v);
} return res;
}
int kth(int l, int r, int k) {
auto [bl, pl] = idx.find(l);
auto [br, pr] = idx.find(r);
if (bl == br) {
vector<int> vec(blocks[bl].val.begin() + pl, blocks[bl].val.begin() + pr + 1);
nth_element(vec.begin(), vec.begin() + k - 1, vec.end());
return vec[k - 1];
}
auto check = [&](int b) -> int {
int lim = min(M, b * B + B - 1);
int cnt = 0;
for (int i = pl; i < blocks[bl].size(); i++) cnt += (blocks[bl].val[i] <= lim);
for (int i = 0; i <= pr; i++) cnt += (blocks[br].val[i] <= lim);
return cnt + sgt.ask(bl + 2, br, b);
};
auto solve = [&]() -> int {
int l = 0, r = NB - 1, res = 0;
while (l <= r) {
int mid = (l + r) >> 1;
if (check(mid) >= k) r = mid - 1, res = mid;
else l = mid + 1;
} return res;
};
int b = solve();
int lo = b * B, hi = min(M, b * B + B - 1), res = 0;
while (lo <= hi) {
int mid = (lo + hi) >> 1;
if (ask(l, r, mid, b) >= k) hi = mid - 1, res = mid;
else lo = mid + 1;
} return res;
}
void modify(int x, int c) {
auto [b, p] = idx.find(x);
int old = blocks[b].val[p];
blocks[b].val[p] = c;
blocks[b].del(old), blocks[b].add(c);
sgt.update(b + 1, old >> O, -1), sgt.update(b + 1, c >> O, 1);
}
void insert(int x, int c) {
if (x == tot + 1) {
blocks.back().val.emplace_back(c);
blocks.back().add(c);
sgt.update(blocks.size(), c >> O, 1);
expand(blocks.size() - 1);
} else {
auto [b, p] = idx.find(x);
blocks[b].val.insert(blocks[b].val.begin() + p, c);
blocks[b].add(c);
sgt.update(b + 1, c >> O, 1);
expand(b);
} tot++;
}
int n, a[N], q;
char opt;
void build() {
tot = n;
Block cur;
for (int i = 1; i <= n; i++) {
if (cur.size() > L) cur.build(), blocks.emplace_back(cur), cur = Block();
cur.val.emplace_back(a[i]);
}
cur.build(), blocks.emplace_back(cur), sgt.build();
}
void _main() {
read(n), read(a + 1, a + n + 1), read(q);
build();
for (int l, r, k, last = 0; q--; ) {
readchar(opt), read(l, r), l ^= last, r ^= last;
if (opt == 'Q') read(k), writeln(last = kth(l, r, k ^ last));
else if (opt == 'M') modify(l, r);
else if (opt == 'I') insert(l, r);
} FastIO::flush();
}
::::