P2408不同子串个数(题解)

· · 题解

PS.

此题是字符串入门好题。
显然用AC自动机、KMP等提高组算法不可解。
显然用后缀数组、后缀自动机等算法可以轻易解。
此篇题解讲了SA和SAM的两种做法。
仍然:码风压行,不喜勿喷。
求赞嘤嘤嘤

Problem.

求一个字符串的不同子串有多少个。

Part 1.

Solution.

用SA(后缀数组)做这题。
我们再SA中求出了Height数组,是相邻两个前缀的最长相同前缀。
那么我们只需要在当前答案中减去这个相同前缀的数量的贡献。
就能求出不同的子串数量了。

Coding.

#include<bits/stdc++.h>
using namespace std;
char s[1000005];int n,m,rk[1000005],sa[1000005],tax[1000005],tp[1000005],h[1000005];
//注意毒瘤码风:这里的h就是Height了,不是那个辅助数组H。
inline void srt()//后缀数组的桶排
{
    memset(tax,0,sizeof(tax));
    for(int i=1;i<=n;i++) tax[rk[tp[i]]]++;
    for(int i=1;i<=m;i++) tax[i]+=tax[i-1];
    for(int i=n;i>=1;i--) sa[tax[rk[tp[i]]]--]=tp[i];
}
inline void work()//求后缀数组
{
    for(int i=1;i<=n;i++) rk[i]=s[i],tp[i]=i;
    m=127,srt();
    for(int w=1,p=1,i;p<n;w<<=1,m=p)
    {
        for(p=0,i=n-w+1;i<=n;i++) tp[++p]=i;
        for(i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
        srt(),memcpy(tp,rk,sizeof(rk)),rk[sa[1]]=p=1;
        for(i=2;i<=n;i++) rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+w]==tp[sa[i-1]+w]?p:++p);
    }
    for(int i=1,j=0,k;i<=n;h[rk[i]]=j,i++) for(j=(j?j-1:j),k=sa[rk[i]-1];s[i+j]==s[k+j];j++);
}
inline long long solve() {long long ans=1ll*n*(n+1)/2;for(int i=1;i<=n;i++) ans-=h[rk[i]];return ans;}
//solve就是统计答案。
//这里用了正难则反,总共有n*(n+1)/2个点,减去重复的就是答案了。
int main() {return scanf("%d%s",&n,s+1),work(),printf("%lld\n",solve()),0;}
//简洁的主程序,这次压行有点过分,主程序都压成一行了,我自裁。

Part 2

Solution.

用SAM(后缀自动机)来做这道题。
首先建立后缀自动机。
然后根据后缀自动机的性质我们可以发现:
在后缀自动机上从根节点开始的每一条路径都是一个子串。
所以直接在SAM上跑DP就好了,其实就是一遍DFS,因为SAM是一个DAG。

Coding.

#include<bits/stdc++.h>
using namespace std;
struct node{int len,fa,s[26];}a[200005];//SAM结构体
int n,lst=1,cnt=1;char s[100005];long long ans[200005];//ans就是dp数组
inline void ins(int c)
{//经典的后缀自动机的加点过程。
    int p=lst,np=lst=++cnt;a[np].len=a[p].len+1;
    for(;p&&!a[p].s[c];p=a[p].fa) a[p].s[c]=np;
    if(!p) return(void)(a[np].fa=1);
    int q=a[p].s[c];
    if(a[q].len==a[p].len+1) return(void)(a[np].fa=q);
    int nq=++cnt;a[nq]=a[q],a[nq].len=a[p].len+1,a[q].fa=a[np].fa=nq;
    for(;p&&a[p].s[c]==q;p=a[p].fa) a[p].s[c]=nq;
}
inline long long dfs(int x)
{//在SAM上跑dfs
    if(ans[x]) return ans[x];
    for(int i=0;i<26;i++) if(a[x].s[i]) ans[x]+=dfs(a[x].s[i])+1;
    return ans[x];
}
int main()
{
    scanf("%d%s",&n,s+1),memset(ans,0,sizeof(ans));
    for(int i=1;i<=n;i++) ins(s[i]-'a');
    return printf("%lld\n",dfs(1)),0;
}