题解 P2408 【不同子串个数】

· · 题解

不用后缀数组的方法

哈希+LCP

首先思路还是把所有的后缀排序,然后减去每一个LCP的值
但后缀排序的结果可以用一种比较暴力的方法,用二分哈希找出当前比较的两个后缀的LCP,然后直接比较LCP的后一个字母的大小,时间复杂度O(n\log^2n )。(还是比后缀排序慢一点)
然后的方法就是\frac{n(n+1)}{2}-\sum(LCP)
证明:所有后缀的前缀,囊括了所有的子串,去除重复的即可。而排出后缀排序结果相邻的LCP,即重复部分就可得到答案。复杂度 O(n\log n )

总的复杂度O(n\log^2n ) 虽然比正统的后缀排序多一个\log,但本题的数据范围可以过。

Tip:在本题的数据中,取哈希的Base=131,Mod=1e9+7会被卡掉。

代码不是很长但挺好写(cmp和求LCP几乎一样)。

#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read()
{
    char ch=getchar(); int nega=1; while(!isdigit(ch)) {if(ch=='-') nega=-1; ch=getchar();}
    int ans=0; while(isdigit(ch)) {ans=ans*10+ch-48;ch=getchar();}
    if(nega==-1) return -ans; return ans;
}
#define N 100005
#define Mod 1000000007
#define Base 233
int n;
int s[N];
char a[N];
int t[N];
int mul[N];
int gethash(int b,int l)
{
    int dat=s[b+l]-s[b-1]*mul[l+2]%Mod;
    dat=(dat+Mod)%Mod;
    return dat;
}
bool cmp(int x,int y)
{
    if(x>y) return !cmp(y,x);
    if(x==y) return false;
    if(a[x]!=a[y]) return a[x]<a[y];
    int l=0,r=n-max(x,y)+1;
    int ans=0;
    while(l<=r)
    {
        int mid=(l+r)>>1;
        int hx=gethash(x,mid);
        int hy=gethash(y,mid);
        if(hx==hy)
        {
            l=mid+1,ans=mid;
        }
        else r=mid-1;
    }
    if(ans==n-max(x,y)+1) return true;
    return a[x+ans+1]<a[y+ans+1];
}
int LCP(int x,int y)
{
    if(x>y) return LCP(y,x);
    if(a[x]!=a[y]) return 0;
    int l=0,r=n-max(x,y);
    int ans=0;
    while(l<=r)
    {
        int mid=(l+r)>>1;
        int hx=gethash(x,mid);
        int hy=gethash(y,mid);
        if(hx==hy)
        {
            l=mid+1,ans=mid;
        }
        else r=mid-1;
    }
    return ans+1;
}
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("in.txt","r",stdin);
#endif
    n=read();
    scanf("%s",a+1);
    mul[1]=1;
    for(int i=2;i<=n+3;i++)
    {
        mul[i]=mul[i-1]*Base%Mod;
    }
    for(int i=1;i<=n;i++)
    {
        s[i]=(s[i-1]*Base+(a[i]-'A'+1))%Mod;
    }
    for(int i=1;i<=n;i++) t[i]=i;
    sort(t+1,t+n+1,cmp);
    int ans=(n+1)*n/2;
    for(int i=2;i<=n;i++)
    {
        ans-=LCP(t[i-1],t[i]);
    }
    cout<<ans<<endl;
    return 0;
}