【笔记】平衡树 学习笔记(上)
二叉搜索树
在学习搜索树之前,这里简要介绍一下二叉搜索树的概念。
二叉搜索树是一种二叉树的树形数据结构,可以被用来维护有序序列。每个结点的儿子最多有两个,有左儿子、右儿子之分。节点上的权值满足左子树内节点的权值严格小于当前节点的权值,当前节点的权值严格小于右子树内节点的权值。按照中序遍历整棵二叉树,输出节点上的权值,可以得到其维护的有序序列。
在这里,我们不允许出现两个相同权值的节点。相同权值的处理办法是,在每个节点上开一个计数器
如上图所示是一棵二叉搜索树。我们将会以它作为例子,讲解二叉搜索树的基本操作。
数组定义
下文代码中,
const int SIZ = 1e5 + 3;
int L[SIZ], R[SIZ], C[SIZ], S[SIZ], W[SIZ];
void pushup(int u){
S[u] = S[L[u]] + C[u] + S[R[u]];
}
int sz;
int newnode(int w){
++ sz;
L[sz] = R[sz] = 0;
C[sz] = S[sz] = 1;
W[sz] = w;
return sz;
}
插入操作
假定现在要往根为
void pushup(int u){
S[u] = S[L[u]] + C[u] + S[R[u]];
}
void insert(int &p, int w){
if(p == 0){
p = newnode(w);
} else {
if(W[p] == w){
++ S[p], ++ C[p];
} else {
if(w < W[p]) insert(L[p], w);
else insert(R[p], w);
pushup(p);
}
}
}
由于插入操作可能会导致该树的根节点发生变化,即如果原来是空树,则树的根节点变成了新建的节点,也就导致
同样地,我们在外部给整棵二叉搜索树执行操作时也是从根节点
此外,将节点插入到当前节点的某棵子树后,子树权值个数会发生变化。当递归过程结束后,进入的那棵子树的权值个数已经正确完成了维护。于是将当前节点
下图展示了向树中插入元素
插入完毕后,从
删除操作
删除操作在普通的二叉搜索树内比较繁琐(不过在某些平衡树内会变得更加简单,也可能会变得极端复杂)。首先显然要递归找到需要删除的节点的编号
- 如果
p 是叶子节点,那么删除后p 子树变为空,同步维护父亲指向它的指针; - 如果
p 是链上节点,那么删除后p 连接着的节点就取代了p 的位置; - 如果
p 既有左儿子又有右儿子,情况则变得比较复杂。我们选择右子树里的最小值,也就是从右儿子开始一直向左走走到的节点。设该节点为q ,将p,q 的权值以及权值个数交换(注意不是直接交换两个节点的位置。交换位置需要修改p,q 的儿子指针,以及它们父亲的儿子指针,讨论量较大)。可以发现,此时二叉树的结构仍然被满足。由于权值交换到了q 上来,也就回到了应该删除的节点在链上的情形,直接删除即可。
这部分比较抽象,我们结合下面删除
首先从根结点出发,找到权值为
接着需要从
我们交换
权值
下文代码实现同样分成两部分组成:
void erase(int &p, int w){
if(W[p] == w){
if(C[p] > 1){
-- C[p], -- S[p]; return;
}
if(L[p] == 0 && R[p] == 0){
p = 0;
} else
if(L[p] == 0 || R[p] == 0){
p = L[p] | R[p];
} else {
replace(R[p], p);
pushup(p);
}
} else {
if(w < W[p]) erase(L[p], w);
else erase(R[p], w);
pushup(p);
}
}
首先找到权值为
void replace(int &o, int p){
if(L[o]){
replace(L[o], p);
pushup(o);
} else {
swap(W[o], W[p]);
swap(C[o], C[p]);
o = R[o];
}
}
一直向左走,直到没有左儿子。交换
在上述函数中,当我们进入了某个节点的某个子树后,由于该子树内的权值个数发生了更新,所以我们从
到这里,有关二叉搜索树的修改操作就已经都完成了。接下来是查询操作。
查询 k 小
假定现在要查询以
- 当
S_{L(p)}\ge k 时,说明左子树内有至少k 个节点比p 要小,所以答案一定在左子树内。递归地查询左子树。 - 否则,当
S_{L(p)}+C_p\ge k 时,说明W_p 就是答案,可以直接返回; - 否则,答案应该在右子树内。此时我们要查询的不是右子树内第
k 小的节点,而是第k-S_{L(p)}-C_p 小。这是因为,前S_{L(p)} 小的元素在左子树内,接下来C_p 小的元素在节点p 里面。
int find_kth(const int p, int k){
int a = S[L[p]], b = a + C[p];
if(a >= k) return find_kth(L[p], k);
if(b >= k) return W[p];
return find_kth(R[p], k - b);
}
查询排名
在这里,权值
- 如果
p=0 ,也就是以p 为根的子树为空,也就没有元素比w 小。答案就是1 。 - 如果
w< W_p ,说明比w 小的元素应该都在p 的左子树内; - 如果
w= W_p ,根据二叉搜索树的性质,比w 小的元素就是左子树内的元素,直接返回S_{L(p)}+1 ; - 如果
w> W_p ,那么p 上的权值以及p 左子树内的权值都要比w 小,所以将答案加上S_{L(p)}+C_p ,此外还要统计右子树内比w 小的元素,所以还要递归处理。
int find_rank(const int p, int k){
if(p == 0) return 1;
if(k < W[p]) return find_rank(L[p], k);
if(k == W[p]) return S[L[p]] + 1;
return S[L[p]] + C[p] + find_rank(R[p], k);
}
查找前驱
我们要找到树上严格比
int find_pre(const int p, int k){
return find_kth(p, find_rank(p, k) - 1);
}
查找后继
我们要找到树上严格比
int find_suc(const int p, int k){
return find_kth(p, find_rank(p, k + 1));
}
完整代码(以模板题为例)
#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
namespace BST{
const int SIZ = 1e5 + 3;
int L[SIZ], R[SIZ], C[SIZ], S[SIZ], W[SIZ];
void pushup(int u){
S[u] = S[L[u]] + C[u] + S[R[u]];
}
int sz;
int newnode(int w){
++ sz;
L[sz] = R[sz] = 0;
C[sz] = S[sz] = 1;
W[sz] = w;
return sz;
}
void insert(int &p, int w){
if(p == 0){
p = newnode(w);
} else {
if(W[p] == w){
++ S[p], ++ C[p];
} else {
if(w < W[p]) insert(L[p], w);
else insert(R[p], w);
pushup(p);
}
}
}
void replace(int &o, int p){
if(L[o]){
replace(L[o], p);
pushup(o);
} else {
swap(W[o], W[p]);
swap(C[o], C[p]);
o = R[o];
}
}
void erase(int &p, int w){
if(W[p] == w){
if(C[p] > 1){
-- C[p], -- S[p]; return;
}
if(L[p] == 0 && R[p] == 0){
p = 0;
} else
if(L[p] == 0 || R[p] == 0){
p = L[p] | R[p];
} else {
replace(R[p], p);
pushup(p);
}
} else {
if(w < W[p]) erase(L[p], w);
else erase(R[p], w);
pushup(p);
}
}
int find_kth(const int p, int k){
int a = S[L[p]], b = a + C[p];
if(a >= k) return find_kth(L[p], k);
if(b >= k) return W[p];
return find_kth(R[p], k - b);
}
int find_rank(const int p, int k){
if(p == 0) return 1;
if(k < W[p]) return find_rank(L[p], k);
if(k == W[p]) return S[L[p]] + 1;
return S[L[p]] + C[p] + find_rank(R[p], k);
}
int find_pre(const int p, int k){
return find_kth(p, find_rank(p, k) - 1);
}
int find_suc(const int p, int k){
return find_kth(p, find_rank(p, k + 1));
}
void dfs(int p){
if(L[p]) dfs(L[p]);
printf("%d, ", W[p]);
if(R[p]) dfs(R[p]);
}
}
int qread(){
int w = 1, c, ret;
while((c = getchar()) > '9' || c < '0') w = (c == '-' ? -1 : 1); ret = c - '0';
while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
return ret * w;
}
int main(){
int n = qread(), r = 0;
up(1, n, i){
int op = qread(), x = qread();
switch(op){
case 1 : BST :: insert(r, x); break;
case 2 : BST :: erase (r, x); break;
case 3 : printf("%d\n", BST :: find_rank(r, x)); break;
case 4 : printf("%d\n", BST :: find_kth (r, x)); break;
case 5 : printf("%d\n", BST :: find_pre (r, x)); break;
case 6 : printf("%d\n", BST :: find_suc (r, x)); break;
}
}
return 0;
}
平衡树简介
考虑上文所实现的二叉搜索树的时间复杂度。记整棵树当前最深的节点的深度为
当整棵树的形态退化成链状,比如从小到大/从大到小依次插入每个元素,那么操作的复杂度将会退化成
可以发现,影响平衡树复杂度的正是整棵树的树高。为了优化复杂度,我们引入了平衡树的概念,即通过一些限制使得整棵二叉树满足某种性质,通常可以约束树高,来将每次操作的最坏复杂度降至
Splay 树
在介绍 Splay 树之前,先对维护二叉查找树平衡的旋转操作进行讲解。
旋转
如图所示,有节点
为了实现旋转操作,显然每个节点除了要维护指向儿子的指针以外,还要设置一个指向父亲的指针。在上述旋转操作中,有这样一些事情发生:
这些是儿子节点发生的变化。此外,
这里指针的赋值操作比较多,也很乱。错误地安排赋值操作的顺序会导致错误。
首先我们修改
- 定义辅助函数
bool is_root(int x)
判断x 是否为根,即x 有无父节点; - 定义辅助函数
bool is_rson(int x)
判断x 是否为右儿子,即检查x 是不是父节点的右儿子。可以发现,如果存在父节点y ,那么X_{y,\mathrm{isrson}(x)}=x 。
先判断
接着处理那个将要被挂在
最后处理
上述过程比较乱,步骤也不唯一。建议自行模拟推导理解整个过程。
void rotate(int x){
int y = F[x], z = F[y];
bool f = is_rson(x);
bool g = is_rson(y);
int &t = X[x][!f];
if(z){ X[z][g] = x; }
if(t){ F[t] = y; }
X[y][f] = t, t = y;
F[y] = x, push_up(y);
F[x] = z, push_up(x);
}
Splay
Splay 操作是 Splay 树的精华。
下图中,我们当前 Splay 的节点是
a 已经是根
无需进行任何操作。
zig / zag
当
zig-zig / zag-zag
当
- 先将
a 的父节点b 向上旋转一级; - 再将
a 向上旋转一级,此时a 取代了c 原本的位置,也就在整个过程中提升了两级。
zig-zag / zag-zig
当
- 先将
a 向上旋转一级; - 再将
a 向上旋转一级,同样地,a 在整个过程中提升了两级,取代了c 的位置。
为什么我们要按照「
时间复杂度
Splay 的时间复杂度是均摊
为了证明 Splay 的复杂度,我们需要引入势能的概念。势能可以描述为与当前状态有关的某个辅助变量。这里举一个有关队列的例子:
维护一个队列
q 。两种操作。
- 操作
1 :弹出队列尾部的k 个元素,保证弹出前队列的长度不小于k ;- 操作
2 :向队列尾部插入一个元素。
直接暴力执行每种操作,那么操作
我们定义这个队列的势能
- 操作
1 复杂度修改为\mathcal O(k)+\Delta \Phi=\mathcal O(k)-k=\mathcal O(0) ; - 操作
2 复杂度修改为\mathcal O(1)+\Delta\Phi=\mathcal O(1)+1=\mathcal O(1) 。
一共
最后我们可以认为操作
为什么可以认为
通过这种方式,我们把复杂度和势能的改变量建立了联系。也就是计算复杂度和势能改变量的和作为某种新的复杂度(由于势能可能减少很多而该操作的复杂度不太大,因此这种和甚至可以是负数)。也就是为下文证明 Splay 复杂度做铺垫。
我们定义
容易发现,对于任意形态的树,它的
接着来分析
zig / zag
子树
zig-zig / zag-zag
由于两者对称,因此仅举 zig-zig 为例。
观察到
注意到这样的性质:
zig-zag / zag-zig
接着是最后部分。
同样地,写出表达式:
依然是考虑用
设
所以,zig-zag / zag-zig 的情形,复杂度可以放缩成
zig / zag 是整个
又因为整个
接着我们将
由于 Splay 树的形态在操作后会发生变化,所以下文不再使用递归写法。同时,由于 Splay 需要维护每个结点的父亲指针,所以在一些具体细节上更加需要注意。
插入
与二叉搜索树的插入基本相同。插入节点后记得设置新的节点的父亲指针。同时,插入完毕后将其
void insert(int &root, int w){
if(root == 0) {root = newnode(w); return;}
int x = root, o = x;
for(;x;o = x, x = X[x][w > W[x]]){
++ S[x]; if(w == W[x]){ ++ C[x], o = x; break;}
}
if(W[o] != w){
if(w < W[o]) X[o][0] = newnode(w), F[sz] = o, o = sz;
else X[o][1] = newnode(w), F[sz] = o, o = sz;
}
splay(root, o);
}
删除
删除操作比二叉搜索树简单一点,但是新增加了一些细节。
首先找到需要删除的节点。将其旋转到树根。
如果其上计数器值大于
- 如果该节点无左右儿子,则直接删除,树根设置为
0 ; - 如果该节点有一个儿子,则将其设置为新的树根。注意,要把它的父指针置为零;
- 否则,情况会变得复杂一点。我们需要合并它的左右儿子。具体做法是,将左子树权值最大的节点旋转到左子树的树根,此时该节点一定没有右儿子(因为它是权值最大的节点,一定不会有节点权值比它大)。然后将其右儿子设置为右子树的树根,同时维护右子树树根的父亲变为它。
注意儿子指针和父亲指针的维护是否正确。不然就会像我一样卡上一年。
void erase(int &root, int w){
int val = S[root];
int x = root, o = x;
for(;x;o = x, x = X[x][w > W[x]]){
-- S[x]; if(w == W[x]){ -- C[x], o = x; break;}
}
splay(root, o);
if(C[o] == 0){
if(X[o][0] == 0 || X[o][1] == 0){
int u = X[o][0] | X[o][1];
if(u != 0) F[root = u] = 0;
} else {
int p = X[o][0]; F[p] = 0;
int q = X[o][0];
while(X[q][1]) q = X[q][1];
splay(p, q);
X[q][1] = X[o][1];
F[X[o][1]] = q;
pushup(q);
root = q;
}
}
}
查询 k 小 / 查询排名 / 查询前驱 / 查询后继
因为是查询操作,不会导致节点的增加或者减少,所以与二叉搜索树基本相同。区别在于查询完毕后要记得
int find_rank(int &root, int w){
int x = root, o = x, a = 0;
for(;x;){
if(w < W[x])
o = x, x = X[x][0];
else {
a += S[X[x][0]];
if(w == W[x]){
o = x; break;
}
a += C[x];
o = x, x = X[x][1];
}
}
splay(root, o); return a + 1;
}
int find_kth(int &root, int w){
int x = root, o = x, a = 0;
for(;x;){
if(w <= S[X[x][0]])
o = x, x = X[x][0];
else {
w -= S[X[x][0]];
if(w <= C[x]){
o = x; break;
}
w -= C[x];
o = x, x = X[x][1];
}
}
splay(root, o); return W[x];
}
int find_pre(int &root, int w){
return find_kth(root, find_rank(root, w) - 1);
}
int find_suc(int &root, int w){
return find_kth(root, find_rank(root, w + 1));
}
参考代码
#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
typedef unsigned int u32;
typedef unsigned long long u64;
namespace Splay{
const int SIZ = 1e6 + 1e5 + 3;
int F[SIZ], C[SIZ], S[SIZ], W[SIZ], X[SIZ][2], sz;
bool is_root(int x){ return F[x] == 0;}
bool is_rson(int x){ return X[F[x]][1] == x;}
int newnode(int w){
W[++ sz] = w, C[sz] = S[sz] = 1, F[sz] = 0;
return sz;
}
void pushup(int x){
S[x] = C[x] + S[X[x][0]] + S[X[x][1]];
}
void rotate(int x){
int y = F[x], z = F[y];
bool f = is_rson(x);
bool g = is_rson(y);
int &t = X[x][!f];
if(z){ X[z][g] = x; }
if(t){ F[t] = y; }
X[y][f] = t, t = y;
F[y] = x, pushup(y);
F[x] = z, pushup(x);
}
void splay(int &root, int x){
for(int f = F[x];f = F[x], f;rotate(x))
if(F[f]) rotate(is_rson(x) == is_rson(f) ? f : x);
root = x;
}
void insert(int &root, int w){
if(root == 0) {root = newnode(w); return;}
int x = root, o = x;
for(;x;o = x, x = X[x][w > W[x]]){
++ S[x]; if(w == W[x]){ ++ C[x], o = x; break;}
}
if(W[o] != w){
if(w < W[o]) X[o][0] = newnode(w), F[sz] = o, o = sz;
else X[o][1] = newnode(w), F[sz] = o, o = sz;
}
splay(root, o);
}
void erase(int &root, int w){
int val = S[root];
int x = root, o = x;
for(;x;o = x, x = X[x][w > W[x]]){
-- S[x]; if(w == W[x]){ -- C[x], o = x; break;}
}
splay(root, o);
if(C[o] == 0){
if(X[o][0] == 0 || X[o][1] == 0){
int u = X[o][0] | X[o][1];
if(u != 0) F[root = u] = 0;
} else {
int p = X[o][0]; F[p] = 0;
int q = X[o][0];
while(X[q][1]) q = X[q][1];
splay(p, q);
X[q][1] = X[o][1];
F[X[o][1]] = q;
pushup(q);
root = q;
}
}
}
int find_rank(int &root, int w){
int x = root, o = x, a = 0;
for(;x;){
if(w < W[x])
o = x, x = X[x][0];
else {
a += S[X[x][0]];
if(w == W[x]){
o = x; break;
}
a += C[x];
o = x, x = X[x][1];
}
}
splay(root, o); return a + 1;
}
int find_kth(int &root, int w){
int x = root, o = x, a = 0;
for(;x;){
if(w <= S[X[x][0]])
o = x, x = X[x][0];
else {
w -= S[X[x][0]];
if(w <= C[x]){
o = x; break;
}
w -= C[x];
o = x, x = X[x][1];
}
}
splay(root, o); return W[x];
}
int find_pre(int &root, int w){
return find_kth(root, find_rank(root, w) - 1);
}
int find_suc(int &root, int w){
return find_kth(root, find_rank(root, w + 1));
}
}
int qread(){
int w=1,c,ret;
while((c = getchar()) > '9' || c < '0') w = (c == '-' ? -1 : 1); ret = c - '0';
while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
return ret * w;
}
int main(){
using namespace Splay;
int n = qread(), m = qread(), root = 0;
up(1, n, i){
int a = qread(); insert(root, a);
}
int last_ans = 0, ans = 0;
up(1, m, i){
int op = qread(), x = qread() ^ last_ans;
switch(op){
case 1 : insert(root, x); break;
case 2 : erase (root, x); break;
case 3 : ans ^= (last_ans = find_rank(root, x)); break;
case 4 : ans ^= (last_ans = find_kth (root, x)); break;
case 5 : ans ^= (last_ans = find_pre (root, x)); break;
case 6 : ans ^= (last_ans = find_suc (root, x)); break;
}
}
printf("%d\n", ans);
return 0;
}