浅谈用WBLT实现的平衡树

· · 题解

Part 1:算法介绍

顾名思义,WBLT(Weight Balanced Leafy Tree) 是一种把 WBT(Weight Balanced Tree)和 LT(Leafy Tree)杂交形成的数据结构。

那么什么是 WBT 呢?其实它就是在每一个节点上面储存这个节点下子树大小,通过保持左右子树的大小关系在一定范围来保证树高,以达到平衡的目的。

那么什么是 LT 呢?其实它就是维护的原始信息仅存储在树的 叶子节点 上,而非叶子节点仅用于维护子节点信息和维持数据结构的形态的树。广为人知的线段树就是一种 LT。

接下来我们探讨一下 WBLT 的一些性质。

  1. WBLT 的每个节点都要记录四个数据:权值 v、这个节点下子树大小 s、左儿子 l、右儿子 r
  2. 我们规定每个非叶子节点一定有两个子节点,这个节点要维护其子节点信息的合并。
  3. 显然,大部分平衡树每个非叶子节点的右儿子的权值大于等于左儿子的权值。
  4. 在 WBLT 中,还有所有非叶子节点的权值等于右儿子的权值。容易推出,每一个节点的权值就是以之为根的子树中的最大权值。

这样听起来 WBLT 很像所有叶子从左到右递增的一棵维护区间最大值的动态开点线段树了。

接下来我们讨论它的操作。为了方便理解,先给出一些定义。

#define l(x) t[x].l
#define r(x) t[x].r
const int vr(2);//维护平衡的参数,下面再讲用途
struct wbl{
    int l,r,v,s;
}t[8000005];

更新节点编号为 x 的节点的信息的代码。

void pushup(int x){
    //如果没有左儿子那么就是叶子节点
    if (!l(x))    return (void)(t[x].s = 1);//叶子大小设定为1
    //如果x是非叶子节点,维护x的s和v
    t[x].s = t[l(x)].s + t[r(x)].s,t[x].v = t[r(x)].v;
}

a,b 两个节点分别作为左、右儿子连到给定的 x 上的代码。

void merge(int &x,int a,int b){
    t[x = ++ cnt] = {a,b,t[b].v,t[a].s + t[b].s};
  //根据上面给的性质建就行了
}

使树保持平衡的关键代码。

void bal(int x){//检查x的左右子树是否失衡
    if (t[l(x)].s > t[r(x)].s * vr)//如果左子树比右子树大比较多
        merge(r(x),r(l(x)),r(x)),
        l(x) = l(l(x));//这两行为右旋
    else if (t[r(x)].s > t[l(x)].s * vr)//反之
        merge(l(x),l(x),l(r(x))),
        r(x) = r(r(x));//则左旋
}

结合图片理解一下。

那么如何插入一个节点呢?

void ins(int x,int v){//x为当前节点,v为要插入的值
    if (!l(x))//找到叶子了,把这一个节点变成非叶子节点
        t[l(x) = ++ cnt] = {0,0,min(v,t[x].v),1},//在它的左儿子放它原来的值与插入的值中较小的
        t[r(x) = ++ cnt] = {0,0,max(v,t[x].v),1};//在它的右儿子放它原来的值与插入的值中较大的
    else
        ins((t[l(x)].v >= v) ? l(x) : r(x),v);
    //不是叶子继续往下搜,通过将插入值与当前节点值比较来判断应该前往当前节点的左子树还是右子树
    pushup(x),bal(x);//维护数据结构
}

是不是还算比较简单?

删除一个节点的代码和插入的长的差不多。

void del(int x,int v,int fa){//x为当前节点,v为要【数据删除】的值,注意保存一下当前节点的父亲节点
    if (!l(x)){//找到叶子了
        if (t[l(fa)].v == v)         t[fa] = t[r(fa)];
        else if (t[r(fa)].v == v)    t[fa] = t[l(fa)];
    }//可以理解为插入过程的逆过程,即找到与要删除的值权值相等的一个叶子节点,将它和它的父亲节点删除,并用其父亲的另一个儿子代替父亲的位置。
    else
        del((t[l(x)].v >= v) ? l(x) : r(x),v,x),pushup(x),bal(x);
    //和ins一样
}

查找 v 的排名。

int rnk(int x,int v){
    if (!l(x))    return 1;//如果都到叶子节点了这个值在这个子树的排名肯定是1
    return (t[l(x)].v >= v) ? rnk(l(x),v) : rnk(r(x),v) + t[l(x)].s;
    //不然将查询值与当前节点值比较来判断应该前往当前节点的左子树还是右子树,如果是右子树排名加上左子树的大小
}

查询第 s 大的数。

int kth(int x,int s){
    if (t[x].s == s)    return t[x].v;//如果到的这个子树大小与s一样那么第s大的数就是当前节点的权值,因为“每一个节点的权值就是以之为根的子树中的最大权值”。
    return (t[l(x)].s >= s) ? kth(l(x),s) : kth(r(x),s - t[l(x)].s);
    //不然将查询值与当前节点左儿子的s比较来判断应该前往当前节点的左子树还是右子树,如果是右子树要求的子树大小减去左子树的大小
}

然后……就写完了???极短平衡树实现???

Part 2:正确性证明

首先 WBLT 达到理想的满二叉树平衡状态时有 2 \times n - 1 个节点,树高为 \log_2(n) + 1,所以一次操作理论上平均最坏的时间复杂度为 \log_2(n) + 1,约为 \log n

然后注意到 WBLT 维护平衡的方式与 WBT 一样,所以它保持平衡方式的正确性可由 WBT 的得证。其实是蒟蒻太菜了不怎么会证

所以 WBLT 的时间复杂度在本题为 O(n\log n)

WBLT 的优缺点:

优点:实现简单,常数小,支持区间操作,可持久化,可以分裂和合并,总体来说是比较优秀的平衡树。

缺点:需要约两倍的空间,如果不注意垃圾回收,及时回收无用的节点(像我一样)就无法保证空间是线性的。

Part 3:示例代码

Talk ~is ~cheap,show ~me ~the ~code.
#include <bits/stdc++.h>
using namespace std;
#define f(n,m,i) for (register int i(n);i <= m;++ i)
#define nf(n,m,i) for (register int i(n);i >= m;-- i)
#define mf(n,m,i,j) for (register int i(n);i <= m;i += j)
#define nmf(n,m,i,j) for (register int i(n);i >= m;i -= j)
#define dbug(x) cerr << (#x) << ':' << x << ' ';
#define ent cerr << '\n';
#define max(a,b) (((a) > (b)) ? (a) : (b))
#define min(a,b) (((a) < (b)) ? (a) : (b))
#define ll long long
#define gc getchar_unlocked
#define pc putchar_unlocked
#define l(x) t[x].l
#define r(x) t[x].r
int ip(){
    int num(0),fu(1);
    char c(gc());
    while (c < '0' || c > '9'){
        if (c == '-')   fu = -fu;
        c = gc();}
    while (c >= '0' && c <= '9')
        num = num * 10 + c - '0',c = gc();
    return fu * num;
}
void op(int num){
    if (num < 0)    pc('-'),num = -num;
    if (num > 9)    op(num / 10);
    pc(48 ^ (num % 10));
}
int n,opt,rt,cnt;
const int vr(2);
struct wbl{
    int l,r,v,s;
}t[8000005];
void merge(int &x,int a,int b){        t[x = ++ cnt] = {a,b,t[b].v,t[a].s + t[b].s};}
void pushup(int x){
    if (!l(x))    return (void)(t[x].s = 1);
    t[x].s = t[l(x)].s + t[r(x)].s,t[x].v = t[r(x)].v;
}
void bal(int x){
    if (t[l(x)].s > t[r(x)].s * vr)
        merge(r(x),r(l(x)),r(x)),
        l(x) = l(l(x));
    else if (t[r(x)].s > t[l(x)].s * vr)
        merge(l(x),l(x),l(r(x))),
        r(x) = r(r(x));
}
void ins(int x,int v){
    if (!l(x))
        t[l(x) = ++ cnt] = {0,0,min(v,t[x].v),1},
        t[r(x) = ++ cnt] = {0,0,max(v,t[x].v),1};
    else
        ins((t[l(x)].v >= v) ? l(x) : r(x),v);
    pushup(x),bal(x);
}
void del(int x,int v,int fa){
    if (!l(x)){
        if (t[l(fa)].v == v)         t[fa] = t[r(fa)];
        else if (t[r(fa)].v == v)    t[fa] = t[l(fa)];}
    else
        del((t[l(x)].v >= v) ? l(x) : r(x),v,x),pushup(x),bal(x);
}
int rnk(int x,int v){
    if (!l(x))    return 1;
    return (t[l(x)].v >= v) ? rnk(l(x),v) : rnk(r(x),v) + t[l(x)].s;
}
int kth(int x,int s){
    if (t[x].s == s)    return t[x].v;
    return (t[l(x)].s >= s) ? kth(l(x),s) : kth(r(x),s - t[l(x)].s);
}
int main(){
    n = ip(),t[rt = ++ cnt] = {0,0,INT_MAX,1};//建根节点
    f(1,n,i){
        int opt(ip()),x(ip());
        if (opt == 1)         ins(rt,x);
        else if (opt == 2)    del(rt,x,-1);
        else if (opt == 3)    op(rnk(rt,x)),pc('\n');
        else if (opt == 4)    op(kth(rt,x)),pc('\n');
        else if (opt == 5)    op(kth(rt,rnk(rt,x) - 1)),pc('\n');//查前驱相当于查这个值的排名,找第(该排名 - 1)个数。
        else                  op(kth(rt,rnk(rt,x + 1))),pc('\n');//查后继相当于查这个值 + 1的排名,找第该排名个数。注意:可能有多个相等的值,所以要这么写。
    }
    return 0;
}