题解 P6164 【【模板】后缀平衡树】

· · 题解

终于抢到第一篇题解了,真不容易

后缀平衡树

前置知识

替罪羊树

原理

后缀平衡树是一个以字典序为关键字的有序字符串集X,它可以资瓷两种复杂度为O(\log n)的操作:

所以,如果对一个字符串逆序执行插入操作,它就是严格的“后缀平衡树”,它的中序遍历也就是后缀数组。但其实上,观察这两个操作,它完全可以维护一个有根树森林。

实现

两种插入操作完全可以视为一种。我们考虑要加入一个字符串xS,我们要做的就是从根开始与当前节点进行比较,然后进入左/右儿子,直到找到一个空位为止。现在的问题就是,如何比较要插入的字符串和节点上的字符串。

显然我们不能O(n)暴力比较。考虑到要插入的字符串xS,除去第一个字母之后的S已经在树中,并且树是有序的。那么当第一个字符相同时,我们完全可以利用这棵树进行O(\log n)比较后面的部分,这样的话,插入操作时O(\log ^ 2 n)的。

考虑如何进行O(1)比较,来做到单次操作是O(\log n)。我们对于每个节点维护一个key值代表在树中的相对位置。具体方法是:选取一个很大的值域,每次进入下一层时根据左右儿子将值域折半,最终节点的值就是当前值域的mid值。当需要比较两个已经在树中的字符串时,直接比较key值即可,复杂度O(1)

对于如何维护平衡,最简单的方法就是使用替罪羊树,注意在重构时要重新赋key值。

本题做法

由于后缀平衡树维护串头操作比较容易,所以对于这种串尾操作,我们可以直接维护反串的后缀平衡树,添加和删除操作正常维护,注意删除时彻底删除节点而不是使用懒标记,这样更不容易犯错,方便后面的查询。

查询时,假设当前串为S,查询以TS中的出现次数,就可以转化为查询有多少个S的后缀是以T为前缀的。我们把T翻转,在后面添加一个字典序极大的字符'Z'+1,查询它的排名为r。然后我们保留后添加的字符,把上一个字符-1,再次查询它的排名为l。这样,得到的r-l就是TS中的出现次数。

代码如下

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
const double LIM = 1e16;
const int MAXN = 2e6;

int n, m;
char str[MAXN], S[MAXN];
double key[MAXN];
int ch[MAXN][2], siz[MAXN];
int tr[MAXN], rt, tcnt;
char op[10];

void Decode(char *s, int mask) {
    int len = strlen(s);
    for (int i = 0; i < len; i++) {
        mask = (mask * 131 + i) % len;
        char t = s[i];
        s[i] = s[mask];
        s[mask] = t;
    }
}

void Update(int now) {
    siz[now] = 1 + siz[ch[now][0]] + siz[ch[now][1]];
}

int Bad(int now) {
    return 1.0 * siz[ch[now][0]] > 0.7 * siz[now] || 1.0 * siz[ch[now][1]] > 0.7 * siz[now];
}

void DFS(int now) {
    if (!now) return;
    DFS(ch[now][0]);
    tr[++tcnt] = now;
    DFS(ch[now][1]);
    ch[now][0] = ch[now][1] = 0;
}

void Rebuild(int &now, int l, int r, double lv, double rv) {
    if (l > r) return;
    int mid = (l + r) >> 1;
    double midv = (lv + rv) / 2;
    now = tr[mid];
    key[now] = midv;
    Rebuild(ch[now][0], l, mid - 1, lv, midv);
    Rebuild(ch[now][1], mid + 1, r, midv, rv);
    Update(now);
}

void Maintain(int &now, double lv, double rv) {
    if (Bad(now)) {
        tcnt = 0;
        DFS(now);
        Rebuild(now, 1, tcnt, lv, rv);
    }
}

int Comp(int x, int y) {
    return S[x] < S[y] || (S[x] == S[y] && key[x - 1] < key[y - 1]);
}

void Insert(int &now, int idx, double lv, double rv) {
    if (!now) {
        now = idx;
        siz[now] = 1;
        key[now] = (lv + rv) / 2;
        ch[now][0] = ch[now][1] = 0;
        return;
    }
    if (Comp(idx, now)) Insert(ch[now][0], idx, lv, key[now]);
    else Insert(ch[now][1], idx, key[now], rv);
    Update(now);
    Maintain(now, lv, rv);
}

void Remove(int &now, int idx) {
    if (now == idx) {
        if (!ch[now][0] || !ch[now][1]) {
            now = (ch[now][0] | ch[now][1]);
        } else {
            int cur = ch[now][0], las = now;
            while (ch[cur][1]) {
                las = cur;
                siz[las]--;
                cur = ch[cur][1];
            }
            if (las == now) {
                ch[cur][1] = ch[now][1];
                now = cur;
                Update(now);
            } else {
                ch[cur][0] = ch[now][0];
                ch[cur][1] = ch[now][1];
                ch[las][1] = 0;
                now = cur;
                Update(now);
            }
        }
        return;
    }
    if (Comp(idx, now)) Remove(ch[now][0], idx);
    else Remove(ch[now][1], idx);
    Update(now);
}

int Com(int now) {
    for (int p = 1; str[p]; p++, now = (now ? now - 1 : 0)) {
        if (str[p] < S[now]) return 1;
        else if (str[p] > S[now]) return 0;
    }
}

int Query(int now) {
    if (!now) return 0;
    int ls = siz[ch[now][0]];
    if (Com(now)) return Query(ch[now][0]);
    else return Query(ch[now][1]) + ls + 1;
}

int main() {
    scanf("%d", &m);
    scanf("%s", S + 1);
    int mask = 0, ans = 0;
    n = strlen(S + 1);
    for (int i = 1; i <= n; i++) Insert(rt, i, 0, LIM);
    for (int i = 1; i <= m; i++) {
        scanf("%s", op);
        if (op[0] == 'A') {
            scanf("%s", str + 1);
            Decode(str + 1, mask);
            int len = strlen(str + 1);
            for (int j = 1; j <= len; j++) {
                S[n + j] = str[j];
                Insert(rt, n + j, 0, LIM);
            }
            n += len;
        } else if (op[0] == 'Q') {
            scanf("%s", str + 1);
            Decode(str + 1, mask);
            int len = strlen(str + 1);
            reverse(str + 1, str + len + 1);
            str[len + 1] = 'Z' + 1;
            str[len + 2] = '\0';
            ans = Query(rt);
            str[len]--;
            ans -= Query(rt);
            printf("%d\n", ans);
            mask ^= ans;
        } else {
            int k;
            scanf("%d", &k);
            for (int j = n; j > n - k; j--) Remove(rt, j);
            n -= k;
        }
    }
    return 0;
}