题解 P2408 【不同子串个数】

· · 题解

字符串裸题,明显可以用SAM做(当然SA也是可以的),可以用来练习SAM。

方法是建完SAM后,直接在上面跑dp,求出路径数量,便可以求出子串数,因为SAM的每条路径都对应着不重复的子串。dp方程也好写:

dp[u]=1+\sum_{(u,v)\in DAWG}dp[v]

就是求出树上路径数的方法。

代码:

    #include<bits/stdc++.h>
    #include<cctype>
    #define For(i,a,b) for(i=(a),i##end=(b);i<=i##end;++i)
    #define Forward(i,a,b) for(i=(a),i##end=(b);i>=i##end;--i)
    #define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i)
    #define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i)
    using namespace std;
    template<typename T>inline void read(T &x){
        T s=0,f=1;char k=getchar();
        while(!isdigit(k)&&k^'-')k=getchar();
        if(!isdigit(k)){f=-1;k=getchar();}
        while(isdigit(k)){s=s*10+(k^48);k=getchar();}
        x=s*f;
    }
    void file(void){
        #ifndef ONLINE_JUDGE
        freopen("DAWG.in","r",stdin);
        freopen("DAWG.out","w",stdout);
        #endif
    }
    const int MAXN=1e6+7;
    static struct DAWG
    {
        int len,link,nxt[30];
        long long cnt;
    }p[MAXN<<1];
    char s[MAXN];
    static int e,n;
    #define Chkmax(a,b) a=a>b?a:b
    inline void extend(int c)
    {
        static int last=0,j,cur,q,clone;
        p[cur=++e].len=p[last].len+1;
        for(j=last;~j&&!p[j].nxt[c];j=p[j].link)p[j].nxt[c]=cur;
        if(j==-1)p[cur].link=0;
        else
        {
            q=p[j].nxt[c];
            if(p[j].len+1==p[q].len)p[cur].link=q;
            else
            {
                clone=++e;
                p[clone].len=p[j].len+1;
                memcpy(p[clone].nxt,p[q].nxt,sizeof p[q].nxt);
                p[clone].link=p[q].link;
                for(;~j&&p[j].nxt[c]==q;j=p[j].link)
                    p[j].nxt[c]=clone;
                p[q].link=p[cur].link=clone;
            }
        }
        last=cur;
    }
    inline bool cmp(int a,int b){return p[a].len>p[b].len;}
    int b[MAXN<<1];
    int main()
    {
        file();
        read(n);
        p[0].link=-1;
        scanf("%s",s);
        Rep(i,0,n-1)extend(s[i]-'a');
        Rep(i,0,e)b[i]=i;
        sort(b,b+e+1,cmp);
        Rep(i,0,e)Rep(j,0,25)
        if(p[b[i]].nxt[j])p[b[i]].cnt+=p[p[b[i]].nxt[j]].cnt+1;
        printf("%lld\n",p[0].cnt);
        return 0;
    }