可持久化 AVL
stripe_python · · 题解
现有的在线做法题解有:FHQ-Treap、01-Trie、WBLT、线段树、替罪羊树(复杂度是假的),以及有旋 Treap。
那么,既然有旋 Treap 可持久化,同样使用旋转维护平衡的 AVL 树也可以可持久化。
先介绍 AVL 树的原理:AVL 树的每个节点维护以它为根的子树高度,并在插入、删除之后通过旋转保证任意节点左右子树的高度差不超过 1,因此树高始终为 $O(\log n)$,单次操作的复杂度也是 $O(\log n)$。
AVL 树的节点定义如下:
// AVL tree node. Besides the value it keeps the multiplicity of that value,
// the total multiplicity of its subtree, and the height of its subtree.
template <class T>
struct AVLTreeNode {
    T val;
    AVLTreeNode<T> *left = nullptr, *right = nullptr;
    long cnt = 1;   // copies of `val` held by this node
    long size = 1;  // cnt summed over the whole subtree
    long high = 1;  // subtree height; a leaf has height 1, an empty tree 0

    AVLTreeNode() : AVLTreeNode(T()) {}
    AVLTreeNode(const T& v) : val(v) {}

    // Refresh size and high from the children; returns this so callers can chain.
    AVLTreeNode<T>* pushup() {
        long child_size = 0, child_high = 0;
        if (left) child_size += left->size, child_high = left->high;
        if (right) {
            child_size += right->size;
            if (right->high > child_high) child_high = right->high;
        }
        size = cnt + child_size;
        high = child_high + 1;
        return this;
    }
};
可持久化的关键是路径复制:在插入、删除以及旋转时,凡是要被修改的节点都先复制一份再修改,绝不直接改动旧版本中的节点,这样旧版本的树始终保持不变。旋转部分的代码如下:
// Persistent rotations. Contract: `p` is already a private copy owned by the
// version being built, so it may be written in place; every other node a
// rotation writes to is copied first, so older versions stay intact.
// Naming: left_rotate lifts p's LEFT child, right_rotate lifts p's RIGHT child.
static node* left_rotate(node* p) {
    node *q = copy(p->left); // Added: q is relinked and re-summed, so copy it
    p->left = q->right, q->right = p, p->pushup();
    return q->pushup();
}
static node* right_rotate(node* p) {
    node *q = copy(p->right); // Added: q is relinked and re-summed, so copy it
    p->right = q->left, q->left = p, p->pushup();
    return q->pushup();
}
static node* left_right_rotate(node* p) {
    // Added: the child is rotated in place, so it must be a private copy too.
    p->left = right_rotate(copy(p->left));
    return left_rotate(p);
}
static node* right_left_rotate(node* p) {
    // Added: the child is rotated in place, so it must be a private copy too.
    p->right = left_rotate(copy(p->right));
    return right_rotate(p);
}
以下是完整版代码,相对于普通 AVL 模板新增的地方用 `// Added` 做了标记。其实 AVL 也挺好写的。
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-version root array: versions 0..q are stored, so q < N.
constexpr int N = 500000 + 5;
// AVL tree node: besides the value it keeps the multiplicity of that value,
// the total multiplicity of its subtree, and the subtree height `high`.
// The tree uses rotations to keep |left->high - right->high| <= 1 everywhere.
template <class T>
struct AVLTreeNode {
    T val;
    AVLTreeNode<T> *left = nullptr, *right = nullptr;
    long cnt = 1;   // copies of `val` held by this node
    long size = 1;  // cnt summed over the whole subtree
    long high = 1;  // subtree height; a leaf has height 1, an empty tree 0

    AVLTreeNode() : AVLTreeNode(T()) {}
    AVLTreeNode(const T& v) : val(v) {}

    // Refresh size and high from the children; returns this so callers can chain.
    AVLTreeNode<T>* pushup() {
        long child_size = 0, child_high = 0;
        if (left) child_size += left->size, child_high = left->high;
        if (right) {
            child_size += right->size;
            if (right->high > child_high) child_high = right->high;
        }
        size = cnt + child_size;
        high = child_high + 1;
        return this;
    }
};
template <class T, class Cmp = std::less<T>>
struct AVLTree {
using node = AVLTreeNode<T>;
Cmp cmp = Cmp();
node *roots[N];
static node* copy(node* cur) { // 复制节点
node *res = nullptr;
if (cur) res = new node, *res = *cur;
return res;
}
static node* get_min(node* cur) {
node *x = cur;
while (x && x->left) x = x->left;
return x;
}
static node* get_max(node* cur) {
node *x = cur;
while (x && x->right) x = x->right;
return x;
}
static node* left_rotate(node* p) {
// 左旋节点 p
node *q = p->left;
p->left = copy(p->left); // Added
p->left = q->right, q->right = p, p->pushup();
return q->pushup();
}
static node* right_rotate(node* p) {
// 右旋节点 p
node *q = p->right;
p->right = copy(p->right); // Added
p->right = q->left, q->left = p, p->pushup();
return q->pushup();
}
static node* left_right_rotate(node* p) {
p->left = right_rotate(p->left);
return left_rotate(p);
}
static node* right_left_rotate(node* p) {
p->right = left_rotate(p->right);
return right_rotate(p);
}
static long get_high(node* p) {return p ? p->high : 0;}
void insert(node*& cur, const T& val) {
if (!cur) {
cur = new node(val);
return;
}
if (val == cur->val) {
cur->cnt++, cur->pushup();
return;
}
if (cmp(val, cur->val)) {
cur->left = copy(cur->left); // Added
insert(cur->left, val), cur->pushup();
if (get_high(cur->left) - get_high(cur->right) >= 2) {
cur = cmp(val, cur->left->val) ?
left_rotate(cur) : left_right_rotate(cur);
}
} else {
cur->right = copy(cur->right); // Added
insert(cur->right, val), cur->pushup();
if (get_high(cur->right) - get_high(cur->left) >= 2) {
cur = cmp(val, cur->right->val) ?
right_left_rotate(cur) : right_rotate(cur);
}
}
cur->pushup();
}
bool remove_node(node*& cur) {
if (!cur) return false;
if (cur->cnt > 1) {
cur->cnt--, cur->pushup();
return true;
}
if (cur->left && cur->right) {
node* replace = this->get_min(cur->right);
cur->cnt = replace->cnt, cur->val = replace->val;
replace->cnt = 1;
remove(cur->right, replace->val), cur->pushup();
if (get_high(cur->left) - get_high(cur->right) >= 2) {
cur = (get_high(cur->left->left) >= get_high(cur->left->right)) ?
left_rotate(cur) : left_right_rotate(cur);
}
} else {
cur = cur->left ? cur->left : cur->right;
}
if (cur) cur->pushup();
return true;
}
bool remove(node*& cur, const T& val) {
if (!cur) return false;
if (val == cur->val) return remove_node(cur);
bool res;
if (cmp(val, cur->val)) {
cur->left = copy(cur->left); // Added
res = remove(cur->left, val), cur->pushup();
if (get_high(cur->right) - get_high(cur->left) >= 2) {
cur = get_high(cur->right->right) >= get_high(cur->right->left) ?
right_rotate(cur) : right_left_rotate(cur);
}
} else {
cur->right = copy(cur->right); // Added
res = remove(cur->right, val), cur->pushup();
if (get_high(cur->left) - get_high(cur->right) >= 2) {
cur = get_high(cur->left->left) >= get_high(cur->left->right) ?
left_rotate(cur) : left_right_rotate(cur);
}
}
if (cur) cur->pushup();
return res;
}
int rank(node* cur, const T& val) const {
if (!cur) return 1;
int left_size = cur->left ? cur->left->size : 0;
if (val == cur->val) return left_size + 1;
if (cmp(val, cur->val)) return rank(cur->left, val);
return rank(cur->right, val) + left_size + cur->cnt;
}
T kth(node* cur, int k) const {
if (!cur) return T();
int left_size = cur->left ? cur->left->size : 0;
if (left_size >= k) return kth(cur->left, k);
if (left_size < k - cur->cnt) return kth(cur->right, k - left_size - cur->cnt);
return cur->val;
}
T predecessor(node *root, const T& val) const {
node *cur = root;
T res = -numeric_limits<T>::max();
while (cur) {
if (cmp(cur->val, val)) res = cur->val, cur = cur->right;
else cur = cur->left;
}
return res;
}
T successor(node *root, const T& val) const {
node *cur = root;
T res = numeric_limits<T>::max();
while (cur) {
if (cmp(val, cur->val)) res = cur->val, cur = cur->left;
else cur = cur->right;
}
return res;
}
};
AVLTree<int> avl;
int q, v, opt, x;
// Driver: reads q operations "v opt x"; operation i builds version i from
// version v. opt 1 inserts x, opt 2 removes one copy of x; opt 3..6 print the
// rank of x, the x-th smallest value, the predecessor and the successor of x.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> q;
    avl.roots[0] = nullptr;  // version 0 is the empty tree
    for (int i = 1; i <= q; i++) {
        std::cin >> v >> opt >> x;
        if (opt == 1 || opt == 2) {
            // Updates path-copy starting from a private copy of version v's root.
            avl.roots[i] = avl.copy(avl.roots[v]);
            if (opt == 1) avl.insert(avl.roots[i], x);
            else avl.remove(avl.roots[i], x);
            continue;
        }
        // Queries do not modify the tree, so version i can share v's root
        // instead of allocating a copy for every query.
        avl.roots[i] = avl.roots[v];
        if (opt == 3) std::cout << avl.rank(avl.roots[i], x) << '\n';
        else if (opt == 4) std::cout << avl.kth(avl.roots[i], x) << '\n';
        else if (opt == 5) std::cout << avl.predecessor(avl.roots[i], x) << '\n';
        else if (opt == 6) std::cout << avl.successor(avl.roots[i], x) << '\n';
    }
    return 0;
}