P8659 题解 | WBLT 介绍

· · 题解

最优解题解。

这道题是可持久化平衡树的板子题。对于 FHQ Treap,我不过多赘述,大家可以看其他题解。

但是,FHQ Treap 处理区间复制的时候,有一个问题,就是他们的随机权值会被直接复制。可能有几种处理办法:在 merge 过程中动态生成大小关系,以及复制的时候新生成权值。这样做看似十分正确,可其实大部分题解时间复杂度都依赖替罪羊式重构,讲白了,就是拿非旋 Treap 模拟替罪羊树。如果拿引用计数去除空间复杂度的干扰,你就会发现,那个玩意儿跑的贼慢。那替罪羊树式维护可持久化非旋 Treap 时间复杂度应该是正确的,但是依赖重构的东西看起来总是不优美,而且给人一种难以接受的感觉 而且老难写了。所以这里我们介绍一个时间复杂度比非旋 Treap 更优,更稳定,不依赖随机数,实现更简单的算法,WBLT。

什么是 WBLT?

Weight Balanced Leafy Tree,是根据 Weight 平衡的,信息记录在 Leafy Nodes 上的树。讲人话就是一颗平衡线段树。大概结构就是这样:

         [1](w=3,sum=9)
         / \
        /   \
       /     \
     [2]     {5}(w=1,sum=4)
(w=2,sum=5)
    /   \
   /     \
  /      {4}(w=1,sum=2)
 /       
{3}(w=1,sum=3)
  1. 可以看到,其实 w 就代表当前节点子树内有多少个叶子节点。
  2. 我们发现,WBLT 是完整二叉树。(孩子数量非 02
  3. WBLT 的 Non Leafy Node 的左右子树都是 WBLT。
  4. 我们定义一个 WBLT 是 Balanced,当且仅当对于任意一个非叶子节点都有 \dfrac{\min(w_l, w_r)}{w} \ge \alpha。原论文指出 \alpha \in [\dfrac{2}{11}, 1 - \dfrac{\sqrt 2}{2}] 可以保证时间复杂度是正确的,显然 \alpha=\dfrac{1}{4} 是一个再好不过的选择了。
  5. 对于 WBLT 的每一个 Non Leafy Node,无论何时都必须保证它是 Balanced

它支持什么?

  1. 区间更新(加、乘、懒标记等)
  2. 区间查询(和,矩阵乘积等)
  3. 区间复制
  4. 可持久化
  5. 增加/删除一个/一段元素(本文未讲述,留做习题

时间复杂度:\mathcal O(n\log n)
空间复杂度:\mathcal O(n) 可持久化:\mathcal O(n\log n)

可以看出,几乎就是一个时间复杂度正确的,和 FHQ Treap 没什么两样的平衡树,甚至,在可持久化上,比 FHQ Treap 还好写!!!(这才是最重要的)

基础函数

  1. 先定义一个最重要的东西,也是大家从来没有写过的东西:

    inline bool needBalance(int wl, int wr) { return wr * 3 < wl; }

    这个函数就是告诉你 wlwr 重太多了,要减肥平衡了!

  2. 肯定要定义每个节点的信息,对吧! ::::info[struct Node]{open}

    enum {L, R};
    struct Node {
        int ch[2], w, refCnt;
        LL sum, add; // tag
    
        Node& operator()(bool c);
        inline int& operator[](bool c) { return ch[c]; }
    
        inline void pushup() { if (w^1) sum = (*this)(L).sum + (*this)(R).sum, w = (*this)(L).w + (*this)(R).w; }
        inline void pull(LL x) { add += x; sum += x * w; }
    
        Node() : ch{}, w(), refCnt(), sum(), add() {}
        Node(int v) : ch{}, w(1), refCnt(1), sum(v), add() {}
        Node(int l, int r) : ch{l, r}, w(), refCnt(1), sum(), add() { pushup(); }
    } tr[N << 4];
    inline Node& Node::operator()(bool c) { return tr[ch[c]]; }

    :::: 这里,如果构造函数传一个参数,就是建一个 Leafy Node,传两个就是根据给定的 ls, rsNon Leafy Node 并且 pushup写得可能有点抽象

  3. 对于所有平衡树,使用指针显然是更优的选择,但是为了卡常提高速度,我们通常使用数组模拟。所以我们就要手动实现 new/mallocdelete/free。这里我使用 allocrecycle 表示对应函数。

    template<class... Args> inline int alloc(Args... args) {
        int u = bintop ? bin[bintop--] : ++top;
        tr[u] = Node(args...); return u;
    }

    众所周知,WBLT 最核心的地方就在于区间复制时间复杂度是对的,所以 WBLT 优势就在于可持久化。而完成永久区间复制要求我们使用引用计数保证空间复杂度线性。 我们定义 ref 表示引用的次数,那么 ref=0 的时候就可以完全删除了。

    inline void recycle(int u) {
        if (u && --tr[u].refCnt == 0) {
            if (tr[u].w ^ 1) recycle(tr[u][L]), recycle(tr[u][R]);
            bin[++bintop] = u;
        }
    }
  4. 我们在进行平衡操作的时候,使用 merge 来维护平衡。

    具体的,如果对于 u 不平衡,我们会采用类似 Splay 的单旋双旋,但是用合并来实现,具体内容请移步OI Wiki。比如我们就可以通过如图所示的方式进行平衡:

        u              u
       / \            / \
      3   1          /   \
     / \     ===>   2     2
    1   2          / \   / \
       / \        1   1 1   1
      1   1

    (这里的节点上的数字都代表 w
    这只是一种例子,一共 5 种情况都大同小异,大家记住就好,因为我不会证。同时注意,合并操作时间复杂度只与左右树相对大小的对数有关。你可以感性理解为均摊的 \mathcal O(1)。事实上如果使用真的单双旋,合并的时间复杂度就是严格 \mathcal O(1),但是这样时间复杂度也非常优秀,还少写两个函数,实测这题只慢了 0.1s。可能因为是递归,所以常数比较大。看代码吧。 ::::info[merge \mathcal O(\log\dfrac{\max(w_l,w_r)}{\min(w_l,w_r)})]{open}

    inline int merge(int l, int r) {
        if (!l || !r) return l | r;
        if (needBalance(tr[l].w, tr[r].w)) { // l too heavy
            auto [ll, lr] = cut(l);
            if (needBalance(tr[lr].w + tr[r].w, tr[ll].w)) {
                auto [lrl, lrr] = cut(lr);
                return merge(merge(ll, lrl), alloc(lrr, r));
            }
            return merge(ll, merge(lr, r));
        }
        if (needBalance(tr[r].w, tr[l].w)) { // r too heavy
            auto [rl, rr] = cut(r);
            if (needBalance(tr[l].w + tr[rl].w, tr[rr].w)) {
                auto [rll, rlr] = cut(rl);
                return merge(merge(l, rll), merge(rlr, rr));
            }
            return merge(merge(l, rl), rr);
        }
        return alloc(l, r);
    }

    :::: 不知道为什么我把里面的 merge 写成 alloc (也就是变成 {\mathcal O(1)})居然也能过,这只能说明 WBLT 太冷门了,还没有人去卡,说不定 WBLT 也是错的。不过目前暂时还是最快的,无论随机还是现有的构造。

  5. 为了能够实现 split 函数,我们还需要三个辅助函数。 ::::info[checkcut(均摊 \mathcal O(1)) 和 pushdown]{open}

    inline void check(int &u) {
        if (tr[u].refCnt == 1) return;
        --tr[u].refCnt;
        if (tr[u].w != 1) ++tr[u](L).refCnt, ++tr[u](R).refCnt;
        tr[u = alloc(tr[u])].refCnt = 1;
    }
    
    inline void pushdown(int &u) {
        if (tr[u].w == 1 || !tr[u].add) return;
        check(u);
        if (tr[u][L]) check(tr[u][L]), tr[u](L).pull(tr[u].add);
        if (tr[u][R]) check(tr[u][R]), tr[u](R).pull(tr[u].add);
        tr[u].add = 0;
    }
    
    inline std::pair<int,int> cut(int& u) {
        if (tr[u].w == 1) return {0,0};
        pushdown(u);
        ++tr[u](L).refCnt, ++tr[u](R).refCnt;
        std::pair<int,int> ret(tr[u][L], tr[u][R]);
        recycle(u);
        u = 0; // u 没啦!再用,你想 use-after-free 喵?
        return ret;
    }

    :::: 这里我们秉持着可持久化原则,如果我被人用了,那我修改节点信息肯定不能直接修改,要另起炉灶。但是我的两个小弟仍然归我,所以他们都又多了一个个师傅。

    这也是为什么我认为可持久化 WBLT 好写的原因,它的可持久化不需要你考虑在非基类函数中一个节点要怎样复制,你只需要对原版的 WBLT 的这三个函数以及 recycle 略加修改,就可以实现可持久化。别问,问就是不会写可持久化 FHQ Treap。。。

  6. 现在让我们来分割线段树平衡树吧!
    这真的和 FHQ Treap 好像啊! ::::info[split \mathcal O(\log n)]{open}

    inline std::pair<int,int> split(int u, int ord) {
        if (!ord) return {0, u};
        if (ord == tr[u].w) return {u, 0};
        auto [l, r] = cut(u);
        if (ord <= tr[l].w) {
            auto [ll, lr] = split(l, ord);
            return {ll, merge(lr, r)};
        }
        auto [rl, rr] = split(r, ord - tr[l].w);
        return {merge(l, rl), rr};
    }

    :::: (将 u 这颗树的前 ord 个节点分割出来,返回前 ord 个节点的树的根指针 和 其余节点的树的根指针)

  7. (Bonus) 这题用不到,但是你写增删的时候就用得到了。如果一颗树不再平衡(needBalance(wl,wr) || needBalance(wr,wl)),那么我们要做的就是对它的两个子树进行合并(u = merge(u->l, u->r))。代码不放了。

实用函数

  1. ::::info[建树]{open}

    int build(int l, int r) {
        if (l == r) return alloc(a[l]);
        int mid = l + r >> 1;
        return alloc(build(l, mid), build(mid + 1, r));
    }

    :::: (将 a[l] \dots a[r] 建成一颗树,返回根节点指针)

  2. ::::info[区间加]{open}

    #define fetch(L, R) ++L, ++R, ++R; auto [tmp, r] = split(root, R); auto [l, u] = split(tmp, L)
    
    inline void plus(int L, int R, int d) {
        fetch(L, R);
        tr[u].pull(d);
        root = merge(merge(l, u), r);
    }

    :::: 令 \forall i \in [l,r]a_i \gets a_i+d

  3. ::::info[区间复制]{open}

    inline void copy(int dL, int dR, int sL, int sR) {
        int src; { // Copy src
            fetch(sL, sR);
            ++tr[src = u].refCnt;
            root = merge(merge(l, u), r);
        }
        fetch(dL, dR);
        root = merge(merge(l, src), r);
        recycle(u);
    }

    :::: 将区间 [sL, sR] 复制到 [dL, dR]。 注意要先提取出来一个版本,在我的实现中等价于 ++refCnt

好的,这题应该做完了,对吧!有兴趣的可以再做做看 P5586,也是一道类似的题目,双倍经验这一块

::::info[代码] 这个是单双旋版本的,你把递归 merge 替换进去也能过。

#include <stdio.h>
#include <algorithm>
#include <assert.h>
constexpr int N = 100005;
using LL = long long;
int a[N];
namespace WBLT {
enum {L, R};
struct Node {
    int ch[2], w, refCnt;
    LL sum, add; // tag
    Node& operator()(bool c);
    inline int& operator[](bool c) { return ch[c]; }
    inline void pushup() { if (w^1) sum = (*this)(L).sum + (*this)(R).sum, w = (*this)(L).w + (*this)(R).w; }
    inline void pull(LL x) { add += x; sum += x * w; }
    Node() : ch{}, w(), refCnt(), sum(), add() {}
    Node(int v) : ch{}, w(1), refCnt(1), sum(v), add() {}
    Node(int l, int r) : ch{l, r}, w(), refCnt(1), sum(), add() { pushup(); }
} tr[N << 4];
inline Node& Node::operator()(bool c) { return tr[ch[c]]; }

namespace Data {
int bin[N << 4], bintop = 0, top = 0;
namespace Funcs {
template<class... Args> inline int alloc(Args... args) {
    // if (top >= (N << 4)) exit(0);
    int u = bintop ? bin[bintop--] : ++top;
    assert(u);
    tr[u] = Node(args...); return u;
}
inline void recycle(int u) {
    assert(u);
    if (u && --tr[u].refCnt == 0) {
        if (tr[u].w ^ 1) recycle(tr[u][L]), recycle(tr[u][R]);
        bin[++bintop] = u;
    }
}
inline void check(int &u) {
    if (tr[u].refCnt == 1) return;
    --tr[u].refCnt;
    if (tr[u].w != 1) ++tr[u](L).refCnt, ++tr[u](R).refCnt;
    tr[u = alloc(tr[u])].refCnt = 1;
}
} // namespace Funcs
} // namespace Data
using namespace Data::Funcs;

inline void pushdown(int &u) {
    if (tr[u].w == 1 || !tr[u].add) return;
    check(u);
    if (tr[u][L]) check(tr[u][L]), tr[u](L).pull(tr[u].add);
    if (tr[u][R]) check(tr[u][R]), tr[u](R).pull(tr[u].add);
    tr[u].add = 0;
}
inline std::pair<int,int> cut(int& u) {
    assert(u);
    if (tr[u].w == 1) return {0,0};
    pushdown(u);
    ++tr[u](L).refCnt, ++tr[u](R).refCnt;
    recycle(u);
    std::pair<int,int> ret(tr[u][L], tr[u][R]);
    u = 0;
    return ret;
}

[[gnu::always_inline]] inline bool needBalance(int wl, int wr) { return wr * 3 < wl; }
[[gnu::always_inline]] inline bool needDoubleRot(int u, bool x) { return tr[u](x).w * 2 < tr[u](!x).w; }
[[gnu::always_inline]] inline void rotate(int &u, bool x) { // 将 u->x 旋转到 u
    auto [l, r] = cut(u);
    if (x) { // r heavier
        auto [rl, rr] = cut(r);
        u = alloc(alloc(l, rl), rr);
    } else {
        auto [ll, lr] = cut(l);
        u = alloc(ll, alloc(lr, r));
    }
}
/* An Example: 
        u              u
       / \            / \
      3   1          /   \
     / \     ===>   2     2
    1   2          / \   / \
       / \        1   1 1   1
      1   1                           */
[[gnu::always_inline]] inline void balance(int &u) {
    if (tr[u].w == 1) return;
    bool x = tr[u](R).w > tr[u](L).w; // 重儿子编号
    if (!needBalance(tr[u](x).w, tr[u](!x).w)) return;
    if (needDoubleRot(tr[u][x], x)) rotate(tr[u][x], !x); // 先一边倒,然后拉回来
    rotate(u, x);
}
inline int merge(int l, int r) {
    if (!l || !r) return l | r;
    if (needBalance(tr[l].w, tr[r].w)) { // l too heavy
        auto [ll, lr] = cut(l);
        int u = alloc(ll, alloc(lr, r));
        balance(u); return u;
    }
    if (needBalance(tr[r].w, tr[l].w)) { // r too heavy
        auto [rl, rr] = cut(r);
        int u = alloc(alloc(l, rl), rr);
        balance(u); return u;
    }
    return alloc(l, r);
}
inline std::pair<int,int> split(int u, int ord) {
    assert(0 <= ord && ord <= tr[u].w);
    if (!ord) return {0, u};
    if (ord == tr[u].w) return {u, 0};
    auto [l, r] = cut(u);
    if (ord <= tr[l].w) {
        auto [ll, lr] = split(l, ord);
        return {ll, merge(lr, r)};
    }
    auto [rl, rr] = split(r, ord - tr[l].w);
    return {merge(l, rl), rr};
}
int build(int l, int r) {
    if (l == r) return alloc(a[l]);
    int mid = l + r >> 1;
    return alloc(build(l, mid), build(mid + 1, r));
}
int root;

namespace Utils {
inline void init(int n) {
    for (int i = n; i >= 1; i--) a[i+1] = a[i];
    a[0] = a[1] = a[n+2] = a[n+3] = 0;
    root = build(0, n+3); // 前后各两个哨兵
}
#define fetch(L, R) ++L, ++R, ++R; auto [tmp, r] = split(root, R); auto [l, u] = split(tmp, L)
inline void plus(int L, int R, int d) {
    fetch(L, R);
    tr[u].pull(d);
    root = merge(merge(l, u), r);
}
inline void copy(int dL, int dR, int sL, int sR) {
    int src; { // Copy src
        fetch(sL, sR);
        ++tr[src = u].refCnt;
        root = merge(merge(l, u), r);
    }
    fetch(dL, dR);
    root = merge(merge(l, src), r);
    recycle(u);
}
inline LL querySum(int L, int R) {
    fetch(L, R);
    LL res = tr[u].sum;
    root = merge(merge(l, u), r);
    return res;
}
#undef fetch
} // namespace Utils
} // namespace WBLT
using namespace WBLT::Utils;
int main() {
    int n, m;
    scanf("%*d%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    init(n);
    for (int opt, a, b, c, d; m--; ) {
        scanf("%d%d%d", &opt, &a, &b);
        switch (opt) {
        case 1: scanf("%d", &c); plus(a, b, c); break;
        case 2: scanf("%d%d", &c, &d); copy(a, b, c, d); break;
        case 3: printf("%lld\n", querySum(a, b)); break;
        }
    }
    return 0;
}

::::