题解 P5212 【SubString】

· · 题解

个人认为比其他题解写的都更详细更明白/kel

考虑字符串 s 出现的次数,在SAM中,一个节点里面的某个子串的出现次数就是它的子树的出现次数和,因为长的后缀与短的后缀之间信息不共享,所以修改操作本质上是在进行 parent 树上的链加。

考虑一种神奇的写法。每次对于新建的节点 np ,他的贡献应该是 parent 树上 1\sim np 这条路径上的所有点。于是考虑先 merge(1, np) ,把 npsplay 上去之后内部就变成了一棵以 np 为根的一条链,这样就可以不用考虑链加,直接在 np 处打标记即可。

似乎查询操作更为神奇。因为查询的时候只需要对于走到的一个点 x ,直接把他 splay 掉就可以维护信息。看上去似乎不是很对,因为对 x 产生贡献的是一颗子树而不是一条链。但这样做其实有他独特的正确性保证,即每个点都存在且仅存在于一棵 splay ,换 splay 的时候势必要 access,而 access 时本质上就已经把原来的标记给下放干净了,所以每次只有可能是当前的 splay 还有信息没有维护清楚。也就是每次只需要管一条链,剩下的链的标记已经清完了。这样就只需要 splay 一下即可。

写的时候,为了卡常发现了个更神奇的地方,就是在SAM里面抠点插子树/插点这两个操作,由于都保证了父亲不存在,所以 Link 这个操作,本质上是不需要 make_root 的,实测这样就会快很多很多。

#include <cstdio>
#include <cstring>
#include <iostream>

#define il inline
#define fa(x) t[x].fa
#define rv(x) t[x].rev
#define vl(x) t[x].val
#define tg(x) t[x].tag
#define lc(x) t[x].son[0]
#define rc(x) t[x].son[1]

using namespace std ;

const int N = 1000010 ;
const int M = 1500010 ;

struct LCT{
    int fa ;
    int rev ;
    int val ;
    int tag ;
    int son[2] ;
}t[M] ;
int n ;
int m ;
int tp ;
int len ;
int ans ;
int res ;
char s[N] ;
char o[N] ;
int stk[M] ;

bool nroot(int x){
    return ((lc(fa(x)) == x) || (rc(fa(x)) == x)) ;
}
void reverse(int x){
    rv(x) ^= 1 ; swap(lc(x), rc(x)) ;
}
void _down(int x){
    if (rv(x)){
        rv(x) = 0 ;
        if (lc(x)) reverse(lc(x)) ;
        if (rc(x)) reverse(rc(x)) ;
    }
    if (tg(x)){
        if (lc(x)) vl(lc(x)) += tg(x), tg(lc(x)) += tg(x) ;
        if (rc(x)) vl(rc(x)) += tg(x), tg(rc(x)) += tg(x) ;
        tg(x) = 0 ;
    }
}
void rotate(int x){
    bool w = x == rc(fa(x)) ;
    int f1 = fa(x), f2 = fa(f1) ;
    if (nroot(f1))
        t[f2].son[f1 == rc(f2)] = x ;
    t[f1].son[w] = t[x].son[w ^ 1] ;
    fa( t[x].son[w ^ 1] ) = f1 ;
    fa(x) = f2 ; fa(f1) = x ;
    t[x].son[w ^ 1] = f1 ;
}
void splay(int x){
    int y = x ;
    stk[++ tp] = y ;
    while (nroot(y))
        stk[++ tp] = (y = fa(y)) ;
    while (tp) _down(stk[tp --]) ;
    while (nroot(x)){
        int f1 = fa(x) ;
        int f2 = fa(f1) ;
        if (nroot(f1)){
            if ((rc(f1) == x) == (rc(f2) == f1))
                rotate(f1) ; else rotate(x) ;
        }
        rotate(x) ;
    }
}
il void access(int x){
    int y = 0 ;
    for ( ; x ; x = fa(y = x))
        splay(x), rc(x) = y ;
}
il void make_root(int x){
    access(x) ;
    splay(x) ; reverse(x) ;
}
il int find_root(int x){
    access(x) ; splay(x) ;
    while(lc(x)) x = lc(x) ;
    return x ;
}
il void merge(int x, int y){
    make_root(x) ;
    access(y) ; splay(y) ;
}
il void link(int x, int y){
    fa(x) = y ;
}
il void cut(int x, int y){
    merge(x, y) ;
    fa(x) = t[y].son[0] = 0 ;
}
il void Input(int mk) {
    len = strlen(s) ;
    for (int j = 0 ; j < len ; ++ j)
        mk = (mk * 131 + j) % len, swap(s[j], s[mk]) ;
}
struct SAM{
    int size ;
    int last ;
    int len[M] ;
    int fal[M] ;
    int trans[M][2] ;
    void Init(){
        last = ++ size ;
    }
    void Ins(int x){
        int np = ++ size ;
        int q, nq, p = last ;
        len[last = np] = len[p] + 1 ;
        while (p && !trans[p][x])
            trans[p][x] = np, p = fal[p] ;
        if (!p){
            fal[np] = 1 ;
            link(np, 1), merge(1, np) ;
            vl(np) ++, tg(np) ++ ; return ;
        }
        q = trans[p][x] ;
        if (len[q] == len[p] + 1){
            fal[np] = q ;
            link(np, q), merge(1, np) ;
            vl(np) ++, tg(np) ++ ; return ;
        }
        nq = ++ size ;
        cut(fal[q], q) ;
        fal[nq] = fal[q] ;
        link (q, nq) ;
        link (np, nq) ;
        link (nq, fal[q]) ;
        len[nq] = len[p] + 1 ;
        fal[q] = fal[np] = nq ;
        splay(q) ; vl(nq) = vl(q) ;
        memcpy(trans[nq], trans[q], 8) ;
        while (p && trans[p][x] == q)
            trans[p][x] = nq, p = fal[p] ;
        merge(1, np), vl(np) ++, tg(np) ++ ;
    }
}S ;
int main(){
    cin >> m ;
    scanf("%s", s + 1) ;
    len = strlen(s + 1) ; S.Init() ;
    for (int i = 1 ; i <= len ; ++ i) S.Ins(s[i] - 'A') ;
    while (m --){
        scanf("%s", o + 1) ;
        scanf("%s", s), Input(ans) ;
        if (o[1] == 'A'){
            for (int i = 0 ; i < len ; ++ i)
                S.Ins(s[i] - 'A') ; //, cout << i << endl ;
        }
        else{
            int rt = 1 ;
            for (int i = 0 ; i < len ; ++ i){
                rt = S.trans[rt][s[i] - 'A'] ;
                if (!rt) break ;
            }
            if (!rt) res = 0 ;
            else splay(rt), res = vl(rt) ;
            printf("%d\n", res) ; ans = ans ^ res ;
        }
    }
    return 0 ;
}