P7409 题解

· · 题解

题解区竟然没有并查集的题解,这里来一个。

先对下标进行离散化,再依照套路先求出 SA 数组以及 rank 数组,并在此基础上求出 height 数组(\text{height}_i=\text{LCP}(\text{sa}_{i},\text{sa}_{i-1}))。

对于一个长度为 \text{len} 的子串,如果它在字符串中出现了 x 次,那么一定能找到一个 k,使得 \min \{ \text{height}_{k \dots k+x-2} \} \ge \text{len}

我们可以枚举子串的长度 \text{len},但是复杂度会爆炸。

考虑使用并查集。具体地,我们从大到小枚举 \text{len},对于 \text{height}_i=\text{len} 的情况,合并 ii-1 两个块。

AC Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=23333333333333333;
const int N=5e5+5;
int n,T,m,len;
int x[N],y[N],sa[N],rk[N],c[N],a[N],he[N],f[N][21],lg[21],h[N],fa[N];
char s[N];
ll sz[N];
vector<int> vec[N];
int get(int x)
{
    if(fa[x]==x) return x;
    return fa[x]=get(fa[x]);
}
void prework()
{
    lg[0]=1;
    for(int i=1;i<=20;i++) lg[i]=lg[i-1]<<1;
}
int ask(int l,int r)
{
    int k=(int)(log(r-l+1)/log(2));
    return min(f[l][k],f[r-(1<<k)+1][k]);
}
void SA()
{
    m=122;
    for(int i=1;i<=n;i++) c[x[i]=s[i]]++;
    for(int i=2;i<=m;i++) c[i]+=c[i-1];
    for(int i=n;i>=1;i--) sa[c[x[i]]--]=i;
    for(int k=1;k<=n;k<<=1)
    {
        int num=0;
        for(int i=n-k+1;i<=n;i++) y[++num]=i;
        for(int i=1;i<=n;i++) if(sa[i]>k) y[++num]=sa[i]-k;
        for(int i=1;i<=m;i++) c[i]=0;
        for(int i=1;i<=n;i++) c[x[i]]++;
        for(int i=2;i<=m;i++) c[i]+=c[i-1];
        for(int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
        swap(x,y),num=1,x[sa[1]]=1;
        for(int i=2;i<=n;i++)
        {
            if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]) x[sa[i]]=num;
            else x[sa[i]]=++num;
        }
        if(num==n) break;
        m=num;
    }
    for(int i=1;i<=n;i++) rk[sa[i]]=i;
    int k=0;
    for(int i=1;i<=n;i++)
    {
        if(rk[i]>1)
        {
            if(k) k--;
            while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
            he[rk[i]]=k;
        }
   }
   for(int i=1;i<=n;i++) f[rk[i]][0]=he[rk[i]];
   int t=(int)(log(n)/log(2))+1;
   for(int j=1;j<=t;j++)
   {
        for(int i=1;i<=n-lg[j]+1;i++) f[i][j]=min(f[i][j-1],f[i+lg[j-1]][j-1]);
   }
}
void solve()
{
    ll ans=0;
    for(int i=1;i<=len;i++) scanf("%d",&a[i]);
    sort(a+1,a+1+len,[](const int &x,const int &y){return rk[x]<rk[y];});
    len=unique(a+1,a+1+len)-a-1;
    for(int i=2;i<=len;i++) h[i]=ask(rk[a[i-1]]+1,rk[a[i]]),vec[h[i]].push_back(i);
    for(int i=1;i<=len;i++) fa[i]=i,sz[i]=1;
    for(int i=n-1;i>=0;i--)
    {
        for(auto j:vec[i])
        {
            int x=get(j),y=get(j-1);
            if(x==y) continue;
            ans=(ans+1ll*i*sz[x]*sz[y])%mod;
            sz[x]+=sz[y],fa[y]=x;
        }
    }
    printf("%lld\n",ans);
    for(int i=2;i<=len;i++) vec[h[i]].clear();
}
int main()
{
    scanf("%d%d%s",&n,&T,s+1); prework(); SA();
    while(T--) scanf("%d",&len),solve();
    return 0;   
}