题解:P5404 [CTS2019] 重复

· · 题解

下文 \operatorname{MSET} 同无标号无根树计数使用的 \operatorname{MSET} 变换。两边取 \ln 即可在 O(n\log n) 内时间计算其逆。

限制只需考虑所需统计字符串 t 的最小表示法。若希望 t 本身是其最小表示法,则 t 是 lyndon 串的若干重复。如果 t=t'^{k},则 t' 对答案贡献为 m/k。因此下面只考虑 t 本身是 lyndon 串(记为 t\in L)的情况。

如果没有字典序小于 s 的限制:由于每个字符串有唯一的 lyndon 分解 S=T_{1:m},满足 T_{i}\ge T_{i+1},T_j\in L,因此设所有字符串的生成函数是 S,有:

\operatorname{MSET}(L)=S

逆变换 MSET 即可求出。

加入限制后,希望保留 \operatorname{MSET} 的统计法。此时,仅保留 s 分解的第一个 lyndon 串 s_1 并重复之使得长度为 m,记为 s_1'。此时取所有字典序 \le s_1' 的 lyndon 串集合 L'\operatorname{MSET}(L) 即为字典序 <s 的字符串集合带上 s_1

原因是考虑两字符串 s,t 的比较。如果 |s_1|> |t_1|,由于 t_2\le t_1s_1\in L,若 t_1s_1 的前缀,t_{2:m} 必能决定 t<s

否则,|s_1|\le |t_1|,这只会在 s_1t_1 的前缀时产生疑问。但由于 t_1\le s_1',则 t_1|s_1| 后的位置仍然 \le s_1,而此时 t_1 开头就并非最小后缀(不如去掉 s_1 后的后缀),这就产生了矛盾。

上面是小于等于。但是由于 lyndon 串不存在小于长度的周期,实际上取等只在 s_1。这就说明了正确性。

瓶颈在逆 MSET 变换,时间复杂度 O(m\log m)

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=1e5+5,mod=998244353,g=3,ig=(mod+1)/3;
int lim,L,r[maxn];
void predo(int n){
    lim=1,L=0;
    while(lim<=n)lim<<=1,L++;
    for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<L-1);
}
#define Poly vector<int> 
int qp(int a,int b){
    if(b==0)return 1;
    int T=qp(a,b>>1);T=T*T%mod;
    if(b&1)return T*a%mod;
    return T;
}
void ntt(Poly &a,int tp){
    for(int i=0;i<lim;i++)if(r[i]>i)swap(a[i],a[r[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int wn=qp(tp==1?g:ig,(mod-1)/(mid<<1));
        for(int j=0;j<lim;j+=(mid<<1)){
            int W=1;
            for(int k=0;k<mid;k++,W=W*wn%mod){
                int x=a[j+k],y=a[j+k+mid]*W%mod;
                a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
    if(tp==-1){
        int I=qp(lim,mod-2);
        for(int i=0;i<lim;i++)a[i]=a[i]*I%mod;
    }
}
Poly operator +(Poly a,Poly b){
    Poly c(max(a.size(),b.size()));
    for(int i=0;i<c.size();i++){
        c[i]=0;
        if(a.size()>i)c[i]+=a[i];
        if(b.size()>i)c[i]+=b[i];
        c[i]%=mod;
    }
    return c;
}
Poly operator -(Poly a,Poly b){
    Poly c(max(a.size(),b.size()));
    for(int i=0;i<c.size();i++){
        c[i]=0;
        if(a.size()>i)c[i]+=a[i];
        if(b.size()>i)c[i]-=b[i]-mod;
        c[i]%=mod;
    }
    return c;
}
Poly operator *(Poly a,Poly b){
    Poly A=a,B=b,c;
    predo(a.size()+b.size());
    A.resize(lim),B.resize(lim);c.resize(lim);
    ntt(A,1);ntt(B,1);
    for(int i=0;i<lim;i++)c[i]=A[i]*B[i]%mod;
    ntt(c,-1);
    c.resize(a.size()+b.size()-1);
    return c;
}
Poly inv(const Poly &a,int len){
    if(len==1){
        Poly b(1);
        b[0]=qp(a[0],mod-2);
        return b;
    }
    Poly b=inv(a,(len+1)>>1),c(len);
    for(int i=0;i<len;i++)c[i]=a[i];
    predo(len*2-1);
    b.resize(lim),c.resize(lim);
    ntt(c,1);ntt(b,1);
    for(int i=0;i<lim;i++)b[i]=(2-b[i]*c[i]%mod+mod)%mod*b[i]%mod;
    ntt(b,-1);b.resize(len);
    return b;
}
Poly ln(const Poly &a,int len){
    Poly b=inv(a,len),F(len-1);
    for(int i=0;i<len-1;i++)F[i]=(i+1)*a[i+1]%mod;
    b=b*F;b.resize(len);
    for(int i=len-1;i>0;i--)b[i]=b[i-1]*qp(i,mod-2)%mod;
    b[0]=0;
    return b;
}
Poly deriv(Poly a){
    if(a.size()==1)return (Poly){0};
    Poly b;b.resize(a.size()-1);
    for(int i=1;i<a.size();i++)b[i]=i*a[i]%mod;
    return b;
}
Poly integ(Poly a){
    Poly b;b.resize(a.size()+1);
    for(int i=0;i<a.size();i++)b[i+1]=a[i]*qp(i+1,mod-2)%mod;
    return b;
}
Poly exp(const Poly &a,int len){
    if(len==1)return (Poly){1};
    Poly b=exp(a,(len+1)>>1);b.resize(len);Poly lnb=ln(b,len);
    for(int i=0;i<len;i++)lnb[i]=(a[i]-lnb[i]+mod)%mod;
    (lnb[0]+=1)%=mod;
    b=b*lnb;b.resize(len);
    return b;
}
vector<int> fac[maxn];
Poly NMSET(Poly A){
    int n=A.size();
    Poly B=ln(A,n),res(n);
    for(int i=1;i<n;i++){
        res[i]=B[i];
        for(auto j:fac[i])if(j<i)(res[i]+=mod-res[j]*qp(i/j,mod-2)%mod)%=mod;
    }
    return res;
}
string t;char s[maxn],S[maxn];
int n,dL[maxn],dR[maxn],cnt=0,m;
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    for(int i=1;i<=2000;i++)for(int j=i;j<=2000;j+=i)fac[j].push_back(i);
    cin>>m>>t;n=t.size();
    for(int i=1;i<=n;i++)s[i]=t[i-1];
    for(int i=1;i<=n;){
        int j=i,k=i+1;
        while(k<=n&&s[j]<=s[k]){
            if(s[j]==s[k])++j;
            else j=i;
            ++k;
        }
        while(i<=j){
            dL[++cnt]=i,dR[cnt]=i+(k-j)-1;
            i+=k-j;
        }
    }
    for(int l=1,r;l<=m;l=r+1){
        r=l+dR[1]-1;int cr=min(r,m);
        for(int j=l;j<=r;j++)S[j]=s[j-l+1];
    }int ans=0;
    int I=qp(26,mod-2),z=1,cr=0;
    Poly F(m+1);
    for(int i=1;i<=m;i++){
        z=z*I%mod;(cr+=(S[i]-'a')*z)%=mod;
        F[i]=qp(26,i)*cr%mod;
    }
    for(int i=0;i<=m;i++)F[i]++;
    Poly G=NMSET(F);
    if(dR[1]<=m)G[dR[1]]--;
    for(auto x:fac[m])(ans+=G[x]*x)%=mod;
    cout<<ans<<endl;
    return 0;
}