线段树套树状数组套值域分块套块状链表

· · 算法·理论

一个搞笑的 P4278 做法,感觉休闲娱乐更适合一点。

L 为阈值建立块状链表,块长超过 2L 则分裂。

对块状链表建立索引 BIT,可以 BIT 上二分将下标定位到块内,省去 O(L) 的暴力扫描。分裂块时重构索引 BIT。

用一棵线段树维护块状链表上的每个块,一个想法是对线段树上每个点开一个数组 r_i 表示当前区间中 i 的排名,空间 O\left(\dfrac{nV}{L} \right)。查询考虑对值域二分转为区间查排名,散块暴力整块线段树上查询,O\left((L+\log n) \log V\right)

但是每次修改需要修改线段树上的一条树链,pushup 时对每个 r_i 都是前缀修改,O\left(V \log \dfrac{n}{L}\right) 无法接受。

将所有 r 开成树状数组,也就是使用线段树套权值 BIT。这样多一个 \log,查询 O\left((L+\log n) \log^2 V\right),修改 O(\log V \log n),代码在下面,过不了。

::::info[代码]

constexpr int N = 3.5e4 + 5, M = 7e4 + 5;
constexpr int B = 800;

vector<vector<int>> lists;

struct Index {   // 索引 BIT
    vector<int> bit;
    void clear() {bit.clear();}
    void build() {
        int n = lists.size();
        bit.resize(n, 0);
        for (int i = 1; i < n; i++) {
            bit[i] += lists[i - 1].size();
            if (i + lowbit(i) < n) bit[i + lowbit(i)] += bit[i];
        }
    }
    void add(int x) {
        for (; x < (int) bit.size(); x += lowbit(x)) bit[x]++;
    }
    std::pair<int, int> find(int k) {
        if (k < lists[0].size()) return {0, k};
        if (bit.empty()) build();
        int p = 0, n = bit.size();
        for (int i = lg(n); i >= 0; i--) {
            int x = p + (1 << i);
            if (x < n && k >= bit[x]) k -= bit[x], p = x;
        } return {p, k};
    }
} idx;

// 维护区间 <= x 的数的个数
struct SGT {
#define ls (rt << 1)
#define rs (rt << 1 | 1)
    int rk[4 * (N + M) / B][M];
    void add(int* bit, int x, int c) {
        for (; x < M; x += lowbit(x)) bit[x] += c;
    }
    int ask(const int* bit, int x) const {
        int res = 1;
        for (; x >= 1; x -= lowbit(x)) res += bit[x];
        return res;
    }
    void build(int l = 0, int r = lists.size() - 1, int rt = 1) {
        if (l == r) {
            memset(rk[rt], 0, sizeof(rk[rt]));
            for (int i : lists[l]) add(rk[rt], i + 1, 1);
            return;
        }
        int mid = (l + r) >> 1;
        build(l, mid, ls), build(mid + 1, r, rs);
        for (int i = 1; i < M; i++) rk[rt][i] = rk[ls][i] + rk[rs][i];
    }
    // 第 x 块加入 / 删除 c
    void update(int x, int c, int type, int l = 0, int r = lists.size() - 1, int rt = 1) {
        add(rk[rt], c + 1, type);
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (x <= mid) update(x, c, type, l, mid, ls);
        else update(x, c, type, mid + 1, r, rs);
    }
    // c 在块 [tl, tr] 中的排名
    int ask(int tl, int tr, int c, int l = 0, int r = lists.size() - 1, int rt = 1) const {
        if (tl <= l && r <= tr) return ask(rk[rt], c + 1);
        int mid = (l + r) >> 1, res = 0;
        if (tl <= mid) res += ask(tl, tr, c, l, mid, ls);
        if (tr > mid) res += ask(tl, tr, c, mid + 1, r, rs);
        return res;
    }
} sgt;

int tot;
int kth(int l, int r, int k) {
    auto [bl, pl] = idx.find(l);
    auto [br, pr] = idx.find(r);
    if (bl == br) {
        vector<int> vec(lists[bl].begin() + pl, lists[bl].begin() + pr + 1);
        nth_element(vec.begin(), vec.begin() + k - 1, vec.end());
        return vec[k - 1];
    }
    auto ask = [&](int x) -> int {
        int cnt = 0;
        for (int i = pl; i < (int) lists[bl].size(); i++) cnt += (lists[bl][i] <= x);
        for (int i = 0; i <= pr; i++) cnt += (lists[br][i] <= x);
        if (bl + 1 <= br - 1) cnt += sgt.ask(bl + 1, br - 1, x);
        return cnt;
    };
    int lo = 0, hi = M - 1, res = 0;
    while (lo <= hi) {
        int mid = (lo + hi) >> 1;
        if (ask(mid) >= k) hi = mid - 1, res = mid;
        else lo = mid - 1;
    } return res;
}
void modify(int x, int c) {
    auto [b, p] = idx.find(x);
    sgt.update(b, lists[b][p], -1), sgt.update(b, c, 1);
    lists[b][p] = c;
}
void pushup(int b, int v) {
    tot++;
    if ((int) lists[b].size() < 2 * B) {
        sgt.update(b, v, 1);
        idx.add(b);
    } else {
        vector<int> half(lists[b].begin() + B, lists[b].end());
        lists[b].erase(lists[b].begin() + B, lists[b].end());
        lists.insert(lists.begin() + b + 1, half);
        sgt.build(), idx.clear();
    }
}
void insert(int x, int c) {
    // mdebug(x, tot, c);
    if (x == tot) {
        lists.back().emplace_back(c);
        pushup(lists.size() - 1, c);
    } else {
        auto [b, p] = idx.find(x);
        lists[b].insert(lists[b].begin() + p, c);
        pushup(b, c);
    }
}

int n, a[N], q;
char opt;
void build() {
    tot = n;
    for (int i = 1; i <= n; i += B) lists.emplace_back(a + i, a + min(n + 1, i + B));
    sgt.build();
}

void _main() {
    read(n), read(a + 1, a + n + 1), read(q);
    build();
    for (int l, r, k, last = 0; q--; ) {
        readchar(opt), read(l, r),  l ^= last, r ^= last;
        if (opt == 'Q') read(k), k ^= last, writeln(last = kth(l - 1, r - 1, k));
        else if (opt == 'M') modify(l - 1, r);
        else if (opt == 'I') insert(l - 1, r);
    } FastIO::flush();
} 

::::

树状数组套可持久化权值线段树可以降到 \log^2,或许能过。但我的思路想到值域分块去了。

考虑值域分块。对 rB 个分成一块,用一个有序 vector 在块内维护。

仍然使用线段树套树状数组维护块状链表,但是这里的树状数组用来维护值域分块前缀和。因此本做法可以称作:线段树套树状数组套值域分块套块状链表。

修改时,更新值域分块和线段树,复杂度 O\left(B+\log \dfrac{V}{B}\log \dfrac{n}{L}\right)

查询时,首先对值域分块二分求出目标的值域块,先扫描散块再询问线段树,复杂度 O\left(\left(L+\log \dfrac{n}{L}\log \dfrac{V}{B}\right)\log B\right)。然后再在值域块内值域二分,复杂度 O\left(\left(L+\dfrac{n}{L} \log L\right) \log B\right)。因此查询总复杂度 O\left( \left(L+\dfrac{n}{L} \log L+\log \dfrac{n}{L}\log \dfrac{V}{B}\right) \log B \right)

本题视 n,q,V 同阶,不妨令 B=L,总复杂度简化为

O\left(nB+\dfrac{n}{B} \log B \log \dfrac{n}{B} + n \log B \log^2 \dfrac{n}{B} \right)

同时,空间瓶颈在线段树部分,为 O\left(\dfrac{n^2}{B^2} \right)

我的实现中,将 B64 来平衡时空复杂度,L800 左右以防止线段树重构过于频繁,可以在最慢 300ms 左右通过本题。代码调了很久,因此写的很乱。

::::info[代码]

constexpr int N = 3.5e4 + 5, M = 7e4 + 5;
constexpr int O = 6, B = 1 << O, NB = (M >> O) + 1, L = 789, Y = L << 1;

struct Block {
    vector<int> val;
    vector<int> bv[NB + 1];
    void build() {
        for (auto& b : bv) b.clear();
        for (int v : val) bv[v >> O].emplace_back(v);
        for (auto& b : bv) sort(b.begin(), b.end());
    }
    void del(int v) {
        auto& b = bv[v >> O];
        b.erase(lower_bound(b.begin(), b.end(), v));
    }
    void add(int v) {
        auto& b = bv[v >> O];
        b.insert(lower_bound(b.begin(), b.end(), v), v);
    }
    int ask(int b, int v) const {
        return upper_bound(bv[b].begin(), bv[b].end(), v) - bv[b].begin();
    }
    int size() const {return val.size();}
};
vector<Block> blocks;

struct Index {   // 索引 BIT
    vector<int> bit;
    void clear() {bit.clear();}
    void build() {
        int n = blocks.size();
        bit.resize(n, 0);
        for (int i = 1; i < n; i++) {
            bit[i] += blocks[i - 1].size();
            if (i + lowbit(i) < n) bit[i + lowbit(i)] += bit[i];
        }
    }
    std::pair<int, int> find(int k) {
        k--;
        if (k < blocks[0].size()) return {0, k};
        if (bit.empty()) build();
        int p = 0, n = bit.size();
        for (int i = lg(n); i >= 0; i--) {
            int x = p + (1 << i);
            if (x < n && k >= bit[x]) k -= bit[x], p = x;
        } 
        return {p, k};
    }
} idx;

struct SGT {
#define ls (rt << 1)
#define rs (rt << 1 | 1)
    struct BIT {
        int a[NB + 1];
        void clear() {memset(a, 0, sizeof(a));}
        void add(int x, int c) {
            for (; x <= NB; x += lowbit(x)) a[x] += c;
        }
        int ask(int x) const {
            int res = 0;
            for (; x >= 1; x -= lowbit(x)) res += a[x]; 
            return res;
        }
    } bit[4 * (N + M) / L];
    void build(int l = 1, int r = blocks.size(), int rt = 1) {
        bit[rt].clear();
        if (l == r) {
            for (int b = 0; b <= NB; b++) bit[rt].add(b + 1, blocks[l - 1].bv[b].size());
            return;
        }
        int mid = (l + r) >> 1;
        build(l, mid, ls), build(mid + 1, r, rs);
        #pragma GCC unroll 8
        for (int i = 0; i <= NB; i++) bit[rt].a[i] = bit[ls].a[i] + bit[rs].a[i];
    }
    void update(int x, int b, int c, int l = 1, int r = blocks.size(), int rt = 1) {
        bit[rt].add(b + 1, c);
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (x <= mid) update(x, b, c, l, mid, ls);
        else update(x, b, c, mid + 1, r, rs);
    }
    int ask(int tl, int tr, int b, int l = 1, int r = blocks.size(), int rt = 1) const {
        if (tl > tr) return 0;
        if (tl <= l && r <= tr) return bit[rt].ask(b + 1);
        int mid = (l + r) >> 1, res = 0;
        if (tl <= mid) res += ask(tl, tr, b, l, mid, ls);
        if (tr > mid) res += ask(tl, tr, b, mid + 1, r, rs);
        return res;
    }
} sgt;

void expand(int pos) {
    int len = blocks[pos].size();
    if (len < Y) {
        int n = idx.bit.size();
        for (int i = pos + 1; i < n; i += lowbit(i)) idx.bit[i]++;
    } else {
        int mid = len >> 1;
        Block b;
        b.val.assign(blocks[pos].val.begin() + mid, blocks[pos].val.end());
        blocks[pos].val.resize(mid);
        blocks[pos].build(), b.build();
        blocks.insert(blocks.begin() + pos + 1, b);
        idx.clear(), sgt.build();
    }
}

int tot;
int ask(int l, int r, int v, int b) {
    auto [bl, pl] = idx.find(l);
    auto [br, pr] = idx.find(r);
    auto F = [&](int b, int l, int r) -> int {
        int res = 0;
        #pragma GCC unroll 8
        for (int i = l; i <= r; i++) {
            if (blocks[b].val[i] <= v) res++;
        } return res;
    };
    if (bl == br) return F(bl, pl, pr);
    int res = F(bl, pl, blocks[bl].size() - 1) + F(br, 0, pr);
    if (bl + 1 <= br - 1) {
        if (b != 0) res += sgt.ask(bl + 2, br, b - 1);
        for (int i = bl + 1; i <= br - 1; i++) res += blocks[i].ask(b, v);
    } return res;
}
int kth(int l, int r, int k) {
    auto [bl, pl] = idx.find(l);
    auto [br, pr] = idx.find(r);
    if (bl == br) {
        vector<int> vec(blocks[bl].val.begin() + pl, blocks[bl].val.begin() + pr + 1);
        nth_element(vec.begin(), vec.begin() + k - 1, vec.end());
        return vec[k - 1];
    }
    auto check = [&](int b) -> int {
        int lim = min(M, b * B + B - 1);
        int cnt = 0;
        for (int i = pl; i < blocks[bl].size(); i++) cnt += (blocks[bl].val[i] <= lim);
        for (int i = 0; i <= pr; i++) cnt += (blocks[br].val[i] <= lim);
        return cnt + sgt.ask(bl + 2, br, b);
    };
    auto solve = [&]() -> int {
        int l = 0, r = NB - 1, res = 0;
        while (l <= r) {
            int mid = (l + r) >> 1;
            if (check(mid) >= k) r = mid - 1, res = mid;
            else l = mid + 1;
        } return res;
    };
    int b = solve();
    int lo = b * B, hi = min(M, b * B + B - 1), res = 0;
    while (lo <= hi) {
        int mid = (lo + hi) >> 1;
        if (ask(l, r, mid, b) >= k) hi = mid - 1, res = mid;
        else lo = mid + 1;
    } return res;
}
void modify(int x, int c) {
    auto [b, p] = idx.find(x);
    int old = blocks[b].val[p];
    blocks[b].val[p] = c;
    blocks[b].del(old), blocks[b].add(c);
    sgt.update(b + 1, old >> O, -1), sgt.update(b + 1, c >> O, 1);
}
void insert(int x, int c) {
    if (x == tot + 1) {
        blocks.back().val.emplace_back(c);
        blocks.back().add(c);
        sgt.update(blocks.size(), c >> O, 1);
        expand(blocks.size() - 1);
    } else {
        auto [b, p] = idx.find(x);
        blocks[b].val.insert(blocks[b].val.begin() + p, c);
        blocks[b].add(c);
        sgt.update(b + 1, c >> O, 1);
        expand(b);
    } tot++;
}

int n, a[N], q;
char opt;
void build() {
    tot = n;
    Block cur;
    for (int i = 1; i <= n; i++) {
        if (cur.size() > L) cur.build(), blocks.emplace_back(cur), cur = Block();
        cur.val.emplace_back(a[i]);
    }
    cur.build(), blocks.emplace_back(cur), sgt.build();
}

void _main() {
    read(n), read(a + 1, a + n + 1), read(q);
    build();
    for (int l, r, k, last = 0; q--; ) {
        readchar(opt), read(l, r),  l ^= last, r ^= last;
        if (opt == 'Q') read(k), writeln(last = kth(l, r, k ^ last));
        else if (opt == 'M') modify(l, r);
        else if (opt == 'I') insert(l, r);
    } FastIO::flush();
} 

::::