P2408不同子串个数(题解)
PS.
此题是字符串入门好题。
显然用AC自动机、KMP等提高组算法不可解。
显然用后缀数组、后缀自动机等算法可以轻易解。
此篇题解讲了SA和SAM的两种做法。
仍然:码风压行,不喜勿喷。
求赞嘤嘤嘤
Problem.
求一个字符串的不同的子串有多少个。
Part 1.
Solution.
用SA(后缀数组)做这题。
我们再SA中求出了Height数组,是相邻两个前缀的最长相同前缀。
那么我们只需要在当前答案中减去这个相同前缀的数量的贡献。
就能求出不同的子串数量了。
Coding.
#include<bits/stdc++.h>
using namespace std;
char s[1000005];int n,m,rk[1000005],sa[1000005],tax[1000005],tp[1000005],h[1000005];
//注意毒瘤码风:这里的h就是Height了,不是那个辅助数组H。
inline void srt()//后缀数组的桶排
{
memset(tax,0,sizeof(tax));
for(int i=1;i<=n;i++) tax[rk[tp[i]]]++;
for(int i=1;i<=m;i++) tax[i]+=tax[i-1];
for(int i=n;i>=1;i--) sa[tax[rk[tp[i]]]--]=tp[i];
}
inline void work()//求后缀数组
{
for(int i=1;i<=n;i++) rk[i]=s[i],tp[i]=i;
m=127,srt();
for(int w=1,p=1,i;p<n;w<<=1,m=p)
{
for(p=0,i=n-w+1;i<=n;i++) tp[++p]=i;
for(i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
srt(),memcpy(tp,rk,sizeof(rk)),rk[sa[1]]=p=1;
for(i=2;i<=n;i++) rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+w]==tp[sa[i-1]+w]?p:++p);
}
for(int i=1,j=0,k;i<=n;h[rk[i]]=j,i++) for(j=(j?j-1:j),k=sa[rk[i]-1];s[i+j]==s[k+j];j++);
}
inline long long solve() {long long ans=1ll*n*(n+1)/2;for(int i=1;i<=n;i++) ans-=h[rk[i]];return ans;}
//solve就是统计答案。
//这里用了正难则反,总共有n*(n+1)/2个点,减去重复的就是答案了。
int main() {return scanf("%d%s",&n,s+1),work(),printf("%lld\n",solve()),0;}
//简洁的主程序,这次压行有点过分,主程序都压成一行了,我自裁。
Part 2
Solution.
用SAM(后缀自动机)来做这道题。
首先建立后缀自动机。
然后根据后缀自动机的性质我们可以发现:
在后缀自动机上从根节点开始的每一条路径都是一个子串。
所以直接在SAM上跑DP就好了,其实就是一遍DFS,因为SAM是一个DAG。
Coding.
#include<bits/stdc++.h>
using namespace std;
struct node{int len,fa,s[26];}a[200005];//SAM结构体
int n,lst=1,cnt=1;char s[100005];long long ans[200005];//ans就是dp数组
inline void ins(int c)
{//经典的后缀自动机的加点过程。
int p=lst,np=lst=++cnt;a[np].len=a[p].len+1;
for(;p&&!a[p].s[c];p=a[p].fa) a[p].s[c]=np;
if(!p) return(void)(a[np].fa=1);
int q=a[p].s[c];
if(a[q].len==a[p].len+1) return(void)(a[np].fa=q);
int nq=++cnt;a[nq]=a[q],a[nq].len=a[p].len+1,a[q].fa=a[np].fa=nq;
for(;p&&a[p].s[c]==q;p=a[p].fa) a[p].s[c]=nq;
}
inline long long dfs(int x)
{//在SAM上跑dfs
if(ans[x]) return ans[x];
for(int i=0;i<26;i++) if(a[x].s[i]) ans[x]+=dfs(a[x].s[i])+1;
return ans[x];
}
int main()
{
scanf("%d%s",&n,s+1),memset(ans,0,sizeof(ans));
for(int i=1;i<=n;i++) ins(s[i]-'a');
return printf("%lld\n",dfs(1)),0;
}