(有代码)SA 建后缀树

· · 算法·理论

搜遍全网,没有找到一篇有代码的 SA 建后缀树博客,所以自己写一篇。

建议先观看前置博客 Luckyblock,下面是该博客单调栈维护右链做法的一种实现,时间复杂度 O(n)

注:

  1. 为了让每个后缀的结束节点都是后缀树的叶子,所以一般在 s 后面添加一个字符集外且比字符集内所有字符大的字符(如 |)。

  2. 设在下文之前已经 O(n) 地求出了添加字符后 ssa_{1...n}\in[1,n]h_{1...n},其中 h_i=\operatorname{lcp}(sa_i,sa_{i+1})

  3. 单调栈里的节点弹出时才统计 fa 是谁。

code:

int sa[N],h[N];
int sk[N],pos[N],fa[N];//sk: 单调栈, pos: 节点深度, fa: 节点的 father 
void build(string s){
    s+='|';
    getsa(s),geth();
    int top=0; 
    pos[0]=0;//根节点
    for(int i=0;i<n-1;i++){
        int lst=0;
        while(pos[sk[top]]>h[i]){
            lst=sk[top--];
            fa[lst]=sk[top];
        }
        //lca(sa[i-1],sa[i])=h[i], 所以深度大于 h_i 的都要弹出
        //弹出之后如果栈顶存在 pos[x]=h[i], 则 lca 已在栈中 
        //否则新建 lca 为 ++tot, 插入 lst 和 sk[top] 之间 
        if(h[i]!=pos[sk[top]]){
            fa[lst]=++tot;
            pos[tot]=h[i];
            sk[++top]=tot;
        }//没有 h_i, 需要新建
        sk[++top]=++tot;
        pos[tot]=n-sa[i+1]+1;
    }
    while(top){
        fa[sk[top]]=sk[top-1];
        top--;
    }// 不要忘记弹出剩余的节点 
}

例题 1:[TJOI2019] 甲苯先生和大中锋的字符串

建后缀树,统计每个节点的 szsz=k 则出现 k 次。每个节点都代表从某个点开始长度为一段区间的子串 s_{i...l},s_{i...l+1},...,s_{i...r},这些出现次数都加 1,差分维护即可。

PS:用 SA 建就是为了严格的 O(n),所以尽量减小常数,不要 dfs 统计,可以边建后缀树边统计或建完后拓扑排序。

PPS:煮波偷懒写的倍增 SA,平添 log ⌊悲⌋。

code

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=100010;
int T,n,k;
string s;
int h[N];
namespace SA{
    int n,rk[N],sa[N],tmp[N],bin[N];
    void SAI(){
        for(int i=1;i<=n;i++)tmp[i]=i,rk[i]=s[i-1];
        int p=128;
        auto Bsort=[&p](){
            fill(bin,bin+p+1,0);
            for(int i=1;i<=n;i++)bin[rk[i]]++;
            for(int i=1;i<=p;i++)bin[i]+=bin[i-1];
            for(int i=n;i;i--)sa[bin[rk[tmp[i]]]--]=tmp[i];
        };
        Bsort();
        for(int j=1;;j<<=1){
            int tot=0;
            for(int i=n-j+1;i<=n;i++)tmp[++tot]=i;
            for(int i=1;i<=n;i++)if(sa[i]>j)tmp[++tot]=sa[i]-j;
            Bsort();
            tmp[sa[1]]=1,p=1;
            for(int i=2;i<=n;i++)
                tmp[sa[i]]=(rk[sa[i]]==rk[sa[i-1]] and rk[sa[i]+j]==rk[sa[i-1]+j])?p:++p;
            copy(tmp,tmp+n+1,rk);
            if(p==n)return;
        }
    }
    int sk[N+N],pos[N+N],sz[N+N],fa[N+N],tot,b[N];//id
    bool vis[N<<1];
    void work(){
        tot=0;
        s=s+'|';
        n=s.size();
        SAI();
        fill(sz,sz+n+n+5,0);
        fill(fa,fa+n+n+5,0);
        fill(b,b+n+1,0);
        fill(vis,vis+n+n+1,0);
        for(int i=1;i<=n;i++)
            rk[sa[i]]=i;
        sa[n+1]=0;
        auto aug=[](int i,int j,int k){
            while(i+k<=n and j+k<=n and s[i+k-1]==s[j+k-1])k++;
            return k;
        };
        for(int i=1;i<=n;i++)
            h[rk[i]]=aug(i,sa[rk[i]+1],max(h[rk[i-1]]-1,0));
        // get SA & height
        int top=0;
        sz[0]=pos[0]=0;
        for(int i=0;i<n-1;i++){
            int lst=0;
            while(pos[sk[top]]>h[i]){
                lst=sk[top--];
                fa[lst]=sk[top];sz[fa[lst]]+=sz[lst];
            }
            if(h[i]!=pos[sk[top]]){
                sz[fa[lst]]-=sz[lst],
                fa[lst]=++tot;
                sz[tot]+=sz[lst],
                pos[tot]=h[i];
                sk[++top]=tot;
            } 
            sk[++top]=++tot;
            pos[tot]=n-sa[i+1]+1;
            sz[tot]=1;
        }
        while(top){
            fa[sk[top]]=sk[top-1];
            sz[fa[sk[top]]]+=sz[sk[top]];
            top--;
        }
        for(int i=0;i<=tot;i++)
            vis[fa[i]]=1;
//      assert(sz[0]==n-1);
        // build suffix tree
        for(int i=0;i<=tot;i++)
            if(sz[i]==k)
                b[pos[fa[i]]+1]++,b[pos[i]+vis[i]]--;
        int ans=0;
        for(int i=1;i<=n;i++)
            if((b[i]+=b[i-1])>=b[ans])ans=i;
        cout<<(!b[ans]?-1:ans)<<'\n';
    }
}
signed main(){
    cin.tie(nullptr)->sync_with_stdio(0);
    cin>>T;
    while(T--){
        cin>>s>>k;
        SA::work();
    }
    return 0;
}

例题 2:[TJOI2015] 弦论

建后缀树,维护每个节点的 sz,dep,建立该节点是哪个后缀。

按字典序遍历后缀树,每走到一个节点计算出它代表的长度为一个区间的子串的出现次数,总次数大于等于 k 时就输出。

时间复杂度 O(n)

code

#include<bits/stdc++.h>
using namespace std;
const int N=5e5+20;
string s;
int n,t,k;
int sa[N],rk[N],h[N];
namespace SA{
    inline void SAIS(int* s,int* st,bool* t,int n,int m){
        int top=0;
        t[n]=0;
        static int cnt[N],nc[N];
        for(int i=n-1;i>=1;--i)t[i]=(s[i]==s[i+1]?t[i+1]:s[i]>s[i+1]);
        for(int i=2;i<=n;++i)if(t[i-1]&&!t[i])st[rk[i]=++top]=i;else rk[i]=0;
        const auto _sort=[&](int* st){
            fill(sa+1,sa+n+1,0);
            fill(cnt,cnt+m+1,0);
            for(int i=1;i<=n;++i)++cnt[s[i]];
            for(int i=1;i<=m;++i)cnt[i]+=cnt[i-1];
            copy(cnt,cnt+m+1,nc);
            for(int i=top;i;--i)sa[nc[s[st[i]]]--]=st[i];
            for(int i=1;i<=m;++i)nc[i]=cnt[i-1]+1;
            for(int i=1;i<=n;++i)
                if(sa[i]-1>0&&t[sa[i]-1])sa[nc[s[sa[i]-1]]++]=sa[i]-1;
            copy(cnt,cnt+m+1,nc);
            for(int i=n;i>=1;--i)
                if(sa[i]-1>0&&!t[sa[i]-1])sa[nc[s[sa[i]-1]]--]=sa[i]-1;
        };
        _sort(st);
        int q=0,*qs=s+n+1,*qst=st+n+1;bool *qt=t+n+1;
        for(int i=1,x=0,y=0;i<=n;++i)if(x=rk[sa[i]]){
            if(!m||st[x+1]-st[x]!=st[y+1]-st[y])++q;
            else for(int l=st[x],r=st[y];l<=st[x+1];++l,++r)if(s[l]!=s[r]){++q;break;}
            qs[y=x]=q;
        }
        if(q<top)SAIS(qs,qst,qt,top,q);
        else for(int i=1;i<=top;++i)sa[qs[i]]=i;
        for(int i=1;i<=top;++i)qs[i]=st[sa[i]];
        _sort(qs);
    }
    int str[N<<1],st[N<<1];
    bool t[N<<1];
    void work(string s){
        n=s.size();
        for(int i=0;i<n;i++)str[i+1]=s[i]+1;
        str[n+1]=0;
        SAIS(str,st,t,n+1,128);
        for(int i=1;i<=n;i++)
            rk[sa[i]=sa[i+1]]=i;
        sa[n+1]=0;
        auto aug=[](int i,int j,int k){
            while(i+k<=n and j+k<=n and str[i+k]==str[j+k])k++;
            return k;
        };
        for(int i=1;i<=n;i++)
            h[rk[i]]=aug(i,sa[rk[i]+1],max(h[rk[i-1]]-1,0));
    }
}
int tot;
int sz[N<<1],pos[N<<1],sk[N<<1],fa[N<<1],bg[N<<1];
vector<int>vc[N<<1];
void build(){
    int top=0;
    sz[0]=pos[0]=0;
    for(int i=0;i<n-1;i++){
        int lst=0;
        while(pos[sk[top]]>h[i]){
            lst=sk[top--];
            fa[lst]=sk[top],sz[fa[lst]]+=sz[lst];
        }
        if(h[i]!=pos[sk[top]]){
            sz[fa[lst]]-=sz[lst],fa[lst]=++tot;
            sz[tot]+=sz[lst],pos[tot]=h[i];
            sk[++top]=tot;bg[tot]=sa[i+1];
        }
        sk[++top]=++tot;
        pos[tot]=n-sa[i+1]+1;
        sz[tot]=1;bg[tot]=sa[i+1]; 
    }
    while(top){
        fa[sk[top]]=sk[top-1];
        sz[sk[top-1]]+=sz[sk[top]];
        top--;
    }
    for(int i=1;i<=tot;i++)
        vc[fa[i]].push_back(i);
}//build suffix tree
void dfs(int x,int fa){
    if(vc[x].empty())pos[x]--;
    if(k<=1ll*(pos[x]-pos[fa])*sz[x]){
        cout<<s.substr(bg[x]-1,pos[fa]+(k+sz[x]-1)/sz[x]);
        exit(0);
    }
    k-=(pos[x]-pos[fa])*sz[x];
    for(int e:vc[x])
        dfs(e,x);
}
signed main(){
//  freopen("in.in","r",stdin);
//  freopen("out.out","w",stdout);
    cin.tie(nullptr)->sync_with_stdio(0);
    cin>>s>>t>>k;
    if(!t){
        SA::work(s);
        for(int i=1;i<=n;i++){
            if(k<=n-sa[i]+1-h[i-1]){
                cout<<s.substr(sa[i]-1,h[i-1]+k);
                return 0;
            }
            k-=n-sa[i]+1-h[i-1];
        }
        cout<<-1;
        return 0;
    }
    s+='|';
    SA::work(s);
    build();
    dfs(0,0);
    cout<<-1;
    return 0;
}