可持久化 AVL

题解

现有的在线题解有:FHQ-Treap、01-Trie、WBLT、线段树、替罪羊树(复杂度是假的)以及有旋 Treap。

既然有旋 Treap 可以可持久化,那么同样使用旋转维护平衡的 AVL 树也可以可持久化。

先介绍 AVL 树的原理:AVL 树的每个节点维护子树高度 $\mathit{high}$,并通过旋转操作保证每个节点都满足 $|\mathit{high}_l-\mathit{high}_r| \le 1$。可以证明,满足这一性质的 AVL 树高度严格为 $O(\log n)$(最坏情况,而非均摊或期望)。

AVL 树的节点定义如下:

// AVL tree node.
//   val  - the key; equal keys share one node,
//   cnt  - multiplicity of val,
//   size - total multiplicity stored in this subtree,
//   high - height of this subtree (a lone node has height 1).
template <class T>
struct AVLTreeNode {
    T val{};
    AVLTreeNode<T> *left = nullptr, *right = nullptr;
    long cnt = 1, size = 1, high = 1;

    AVLTreeNode() = default;
    // explicit: a bare T must not silently convert into a tree node.
    explicit AVLTreeNode(const T& v) : val(v) {}

    // Recomputes size and high from the children; returns this for chaining.
    AVLTreeNode<T>* pushup() {
        size = cnt + (left ? left->size : 0) + (right ? right->size : 0);
        high = std::max(left ? left->high : 0, right ? right->high : 0) + 1;
        return this;
    }
};

每一次旋转、插入、删除时,凡是要被修改的节点都必须先复制一份再修改(路径复制),否则会破坏与之共享节点的历史版本。注意旋转同样会修改子节点的指针,所以被旋转上来的子节点也要先复制,见代码:

// Rotation helpers (excerpt of the AVLTree members below).
// Naming follows this code: left_rotate lifts p's LEFT child above p,
// right_rotate lifts p's RIGHT child above p.
// Precondition: p itself is already a private copy owned by the version
// being built. Every other node whose pointers change is copied here first,
// so nodes still reachable from older versions are never modified.
static node* left_rotate(node* p) {
    node *q = copy(p->left);   // Added: q->right is rewritten, so copy q
    p->left = q->right, q->right = p, p->pushup();
    return q->pushup();
}
static node* right_rotate(node* p) {
    node *q = copy(p->right);   // Added: q->left is rewritten, so copy q
    p->right = q->left, q->left = p, p->pushup();
    return q->pushup();
}
static node* left_right_rotate(node* p) {
    // Added: right_rotate modifies p->left, so hand it a private copy.
    p->left = right_rotate(copy(p->left));
    return left_rotate(p);
}
static node* right_left_rotate(node* p) {
    // Added: left_rotate modifies p->right, so hand it a private copy.
    p->right = left_rotate(copy(p->right));
    return right_rotate(p);
}

以下是完整代码,相对于普通 AVL 模板新增的地方用 `// Added` 做了标记。其实 AVL 也挺好写的。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;

// AVL 树节点类,维护树高high
// AVL 树通过旋转保证 |left->high - right->high| <= 1
template <class T>
struct AVLTreeNode {
    T val;
    AVLTreeNode<T> *left, *right;
    long cnt, size, high;

    AVLTreeNode() :
    val(T()), left(nullptr), right(nullptr), cnt(1), size(1), high(1) {}

    AVLTreeNode(const T& v) :
    val(v), left(nullptr), right(nullptr), cnt(1), size(1), high(1) {}

    AVLTreeNode<T>* pushup() {
        size = cnt + (left ? left->size : 0) + (right ? right->size : 0);
        high = std::max(left ? left->high : 0, right ? right->high : 0) + 1;
        return this;
    }
};

template <class T, class Cmp = std::less<T>> 
struct AVLTree {
    using node = AVLTreeNode<T>;
    Cmp cmp = Cmp();
    node *roots[N];

    static node* copy(node* cur) {   // 复制节点
        node *res = nullptr;
        if (cur) res = new node, *res = *cur;
        return res;
    }

    static node* get_min(node* cur) {
        node *x = cur;
        while (x && x->left) x = x->left;
        return x;
    }
    static node* get_max(node* cur) {
        node *x = cur;
        while (x && x->right) x = x->right;
        return x;
    }

    static node* left_rotate(node* p) {
        // 左旋节点 p
        node *q = p->left;
        p->left = copy(p->left);   // Added
        p->left = q->right, q->right = p, p->pushup();
        return q->pushup();
    }
    static node* right_rotate(node* p) {
        // 右旋节点 p
        node *q = p->right;
        p->right = copy(p->right);   // Added
        p->right = q->left, q->left = p, p->pushup();
        return q->pushup();
    }
    static node* left_right_rotate(node* p) {
        p->left = right_rotate(p->left);
        return left_rotate(p);
    }
    static node* right_left_rotate(node* p) {
        p->right = left_rotate(p->right);
        return right_rotate(p);
    }
    static long get_high(node* p) {return p ? p->high : 0;}

    void insert(node*& cur, const T& val) {
        if (!cur) {
            cur = new node(val);
            return;
        }
        if (val == cur->val) {
            cur->cnt++, cur->pushup();
            return;
        }
        if (cmp(val, cur->val)) {
            cur->left = copy(cur->left);    // Added
            insert(cur->left, val), cur->pushup();
            if (get_high(cur->left) - get_high(cur->right) >= 2) {
                cur = cmp(val, cur->left->val) ? 
                left_rotate(cur) : left_right_rotate(cur);
            }
        } else {
            cur->right = copy(cur->right);   // Added
            insert(cur->right, val), cur->pushup();
            if (get_high(cur->right) - get_high(cur->left) >= 2) {
                cur = cmp(val, cur->right->val) ? 
                right_left_rotate(cur) : right_rotate(cur);
            }
        }
        cur->pushup();
    }

    bool remove_node(node*& cur) {
        if (!cur) return false;
        if (cur->cnt > 1) {
            cur->cnt--, cur->pushup();
            return true;
        }
        if (cur->left && cur->right) {
            node* replace = this->get_min(cur->right);
            cur->cnt = replace->cnt, cur->val = replace->val;
            replace->cnt = 1;
            remove(cur->right, replace->val), cur->pushup();
            if (get_high(cur->left) - get_high(cur->right) >= 2) {
                cur = (get_high(cur->left->left) >= get_high(cur->left->right)) ?
                left_rotate(cur) : left_right_rotate(cur);
            }
        } else {
            cur = cur->left ? cur->left : cur->right;
        }
        if (cur) cur->pushup();
        return true;
    }
    bool remove(node*& cur, const T& val) {
        if (!cur) return false;
        if (val == cur->val) return remove_node(cur);
        bool res;
        if (cmp(val, cur->val)) {
            cur->left = copy(cur->left);   // Added
            res = remove(cur->left, val), cur->pushup();
            if (get_high(cur->right) - get_high(cur->left) >= 2) {
                cur = get_high(cur->right->right) >= get_high(cur->right->left) ? 
                right_rotate(cur) : right_left_rotate(cur);
            }
        } else {
            cur->right = copy(cur->right);   // Added
            res = remove(cur->right, val), cur->pushup();
            if (get_high(cur->left) - get_high(cur->right) >= 2) {
                cur = get_high(cur->left->left) >= get_high(cur->left->right) ? 
                left_rotate(cur) : left_right_rotate(cur);
            }
        }
        if (cur) cur->pushup();
        return res;
    }

    int rank(node* cur, const T& val) const {
        if (!cur) return 1;
        int left_size = cur->left ? cur->left->size : 0;
        if (val == cur->val) return left_size + 1;
        if (cmp(val, cur->val)) return rank(cur->left, val);
        return rank(cur->right, val) + left_size + cur->cnt;
    }
    T kth(node* cur, int k) const {
        if (!cur) return T();
        int left_size = cur->left ? cur->left->size : 0;
        if (left_size >= k) return kth(cur->left, k);
        if (left_size < k - cur->cnt) return kth(cur->right, k - left_size - cur->cnt);
        return cur->val;
    }

    T predecessor(node *root, const T& val) const {
        node *cur = root;
        T res = -numeric_limits<T>::max();
        while (cur) {
            if (cmp(cur->val, val)) res = cur->val, cur = cur->right;
            else cur = cur->left;
        }
        return res;
    }
    T successor(node *root, const T& val) const {
        node *cur = root;
        T res = numeric_limits<T>::max();
        while (cur) {
            if (cmp(val, cur->val)) res = cur->val, cur = cur->left;
            else cur = cur->right;
        }
        return res;
    }
};

AVLTree<int> avl;   // avl.roots[i] holds the root of version i
int q;              // number of operations
int v, opt, x;      // per operation: base version, operation code, argument

// Reads q operations "v opt x". Operation i creates version i from version v:
//   1 insert x, 2 remove one copy of x, 3 print rank of x,
//   4 print the x-th smallest, 5 print predecessor of x, 6 print successor of x.
// Queries (3-6) do not change anything, so version i equals version v.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> q;
    avl.roots[0] = nullptr;
    for (int i = 1; i <= q; i++) {
        cin >> v >> opt >> x;
        if (opt == 1 || opt == 2) {
            // Updates start from a private copy of version v's root, so that
            // root node is never modified in place.
            avl.roots[i] = avl.copy(avl.roots[v]);
            if (opt == 1) avl.insert(avl.roots[i], x);
            else avl.remove(avl.roots[i], x);
            continue;
        }
        // Read-only operation: version i can share version v's tree, no copy needed.
        avl.roots[i] = avl.roots[v];
        switch (opt) {
            case 3: cout << avl.rank(avl.roots[i], x) << '\n'; break;
            case 4: cout << avl.kth(avl.roots[i], x) << '\n'; break;
            case 5: cout << avl.predecessor(avl.roots[i], x) << '\n'; break;
            case 6: cout << avl.successor(avl.roots[i], x) << '\n'; break;
            default: break;
        }
    }
    return 0;
}