浅谈FFT与NTT在字符串匹配中的应用

· · 算法·理论

FFT与NTT

FFT 与 NTT 常用于在 O(n\log n) 的时间处理多项式乘法。

根据题目具体数据范围选择使用 FFT 还是 NTT。

含通用符的字符串匹配问题

在无通用符的字符串匹配问题,常用时间为 O(n+m) 的 KMP 算法。

当然,也可以设匹配函数 P(x)=\sum\limits_{i=0}^{m-1} (A_i-B_{x+i})

但是 A_i-B_{x+i} 并没有正负性,所以要将其平方,P(x)=\sum\limits_{i=0}^{m-1} (A_i-B_{x+i})^2=\sum\limits_{i=0}^{m-1} (A_i^2-2A_iB_{x+i}+B_{x+i}^2)

显然 A_i^2B_{x+i}^2 可以预处理出来,但是 A_iB_{x+i} 并不好处理。

一般对于这种式子需要改写成卷积的形式,设匹配函数 P(x)=\sum\limits_{i=0}^{m-1} (A_{m-i-1}-B_{x-m+i+1})^2

需要将 A 翻转以匹配这个函数,将这个函数拆开来,发现 P(x)=\sum\limits_{i=0}^{m-1} (A_{m-i-1}^2-2A_{m-i-1}B_{x-m+i+1}+B_{x-m+i+1}^2)

发现 A_{m-i-1}B_{x-m+i+1} 就是卷积的形式,便可用 FFT 直接求出答案,时间复杂度 O(n\log n)

P4173 残缺的字符串

这题出现了可以代替一个字符的通用符,可以将通用符的权值设为 0,再乘上权值即可。

P(x)=\sum\limits_{i=0}^{m-1} (A_{m-i-1}-B_{x-m+i+1})^2A_{m-i-1}B_{x-m+i+1}=\sum\limits_{i=0}^{m-1} (A_{m-i-1}^3B_{x-m+i+1}-2A_{m-i-1}^2B_{x-m+i+1}^2+A_{m-i-1}B_{x-m+i+1}^3)

只需要做三次 FFT 即可,时间复杂度 O(n\log n)

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef complex<double> cp;
const double pi=acos(-1);
const int N=2e6+5;
int n,m,lim=1,rev[N],f[N],g[N],p[N],ans=0;
cp a[N],b[N];
char A[N],B[N]; 
void fft(cp *a,int flag){
    for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        cp wn(cos(pi/mid),flag*sin(pi/mid));
        for(int i=mid*2,j=0;j<lim;j+=i){
            cp w(1,0);
            for(int k=0;k<mid;k++,w*=wn){
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y,a[j+mid+k]=x-y;
            }
        }
    }
}
void solve(int opt){
    fft(a,1),fft(b,1);
    for(int i=0;i<lim;i++)a[i]*=b[i];
    fft(a,-1);
    for(int i=0;i<n+m;i++)p[i]+=(int)(a[i].real()/lim+0.5)*opt;
    memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
//  for(int i=0;i<n+m;i++)cout <<p[i]<<" ";cout <<endl;
}
signed main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >>m>>n>>A>>B;
    for(int i=0;i<m/2;i++)swap(A[i],A[m-i-1]);
    for(int i=0;i<m;i++)f[i]=(A[i]=='*'?0:A[i]-'a'+1);
    for(int i=0;i<n;i++)g[i]=(B[i]=='*'?0:B[i]-'a'+1);
    int k=0;
    while(lim<=n+m)lim*=2,k++;
    for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));

    for(int i=0;i<m;i++)a[i]=f[i]*f[i]*f[i];
    for(int i=0;i<n;i++)b[i]=g[i];
    solve(1);
    for(int i=0;i<m;i++)a[i]=f[i]*f[i];
    for(int i=0;i<n;i++)b[i]=g[i]*g[i];
    solve(-2);
    for(int i=0;i<m;i++)a[i]=f[i];
    for(int i=0;i<n;i++)b[i]=g[i]*g[i]*g[i];
    solve(1);
    for(int x=m-1;x<=n-1;x++)if(p[x]==0)ans++;
    cout <<ans<<"\n";
    for(int x=m-1;x<=n-1;x++)if(p[x]==0)cout <<x-m+2<<" ";
    return 0;
}

下面的是使用 NTT 的代码。

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef complex<double> cp;
const double pi=acos(-1);
const int N=2097152,P=998244353,G=3,I=332748118;
int n,m,lim=1,rev[N],f[N],g[N],p[N],ans=0,a[N],b[N];
char A[N],B[N]; 
int qpow(int a,int b){
    int res=1;
    while(b){
        if(b&1)res=res*a%P;
        a=a*a%P,b>>=1;
    }
    return res;
}
int inv(int a){
    return qpow(a,P-2);
}
void ntt(int *a,int flag){
    for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int wn=qpow(flag?G:I,(P-1)/(mid<<1));
        for(int i=mid*2,j=0;j<lim;j+=i){
            int w=1;
            for(int k=0;k<mid;k++,w=w*wn%P){
                int x=a[j+k],y=w*a[j+mid+k]%P;
                a[j+k]=(x+y+P)%P,a[j+mid+k]=(x-y+P)%P;
            }
        }
    }
}
void solve(int opt){
    ntt(a,1),ntt(b,1);
    for(int i=0;i<lim;i++)a[i]=a[i]*b[i]%P;
    ntt(a,0);
    for(int i=0,s=inv(lim);i<n+m;i++)p[i]+=(a[i]*s%P)*opt;
    memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
}
signed main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >>m>>n>>A>>B;
    for(int i=0;i<m/2;i++)swap(A[i],A[m-i-1]);
    for(int i=0;i<m;i++)f[i]=(A[i]=='*'?0:A[i]-'a'+1);
    for(int i=0;i<n;i++)g[i]=(B[i]=='*'?0:B[i]-'a'+1);
    int k=0;
    while(lim<=n+m)lim*=2,k++;
    for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));

    for(int i=0;i<m;i++)a[i]=f[i]*f[i]*f[i];
    for(int i=0;i<n;i++)b[i]=g[i];
    solve(1);
    for(int i=0;i<m;i++)a[i]=f[i]*f[i];
    for(int i=0;i<n;i++)b[i]=g[i]*g[i];
    solve(-2);
    for(int i=0;i<m;i++)a[i]=f[i];
    for(int i=0;i<n;i++)b[i]=g[i]*g[i]*g[i];
    solve(1);
    for(int x=m-1;x<=n-1;x++)if(p[x]==0)ans++;
    cout <<ans<<"\n";
    for(int x=m-1;x<=n-1;x++)if(p[x]==0)cout <<x-m+2<<" ";
    return 0;
}

CF1975G Zimpha Fan Club

本题出现了能代替字符串的通用符。

显然,ST 都有这种通用符的情况是简单的。

S 有通用符,并先处理前后缀,那么 S 的形式应该是 *s_1*s_2*\dots*s_k*

s_i 按顺序全部匹配 T 才算成功匹配,直接卷积时间复杂度为 O(nk\log n)

可以发现,NTT 直接将 s_iT 中所有能匹配的位置都求出来了,但我们只需要最前面的那一个。

可以将 T 中长度为 2\left | s_i\right | 的前缀拿出进行匹配。如果能匹配成功,将匹配的部分即前面的部分从 T 中删去;若果不能匹配,那 T 种长度为 \left | s_i\right | 的前缀就没用了,也删去。

发现一次匹配最少能删去 T 中长度为 \left | s_i\right | 的前缀,时间复杂度为 O(\left | s_i\right |\log \left | s_i\right |)

总时间复杂度为 O(n\log n)

抽象出题人开 n=2\times 10^6,所以不能用 FFT,考验 NTT 板子和字符串处理的速度。

注意 NTT 模数为 998244353 时会被卡,要换一个。

#include <bits/stdc++.h>
using namespace std;
const int N=4194305,P=469762049,G=3,I=156587350;
string s,t;
int ss[N],tt[N],top=0,tot=0;
struct node{
    int n,m,lim=1,rev[N],p[N],a[N],b[N],f[N],g[N];
    int qpow(int a,int b){
        int res=1;
        while(b){
            if(b&1)res=1ll*res*a%P;
            a=1ll*a*a%P,b>>=1;
        }
        return res;
    }
    int inv(int a){
        return qpow(a,P-2);
    }
    void ntt(int *a,int flag){
        for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            int wn=qpow(flag?G:I,(P-1)/(mid<<1));
            for(int i=mid*2,j=0;j<lim;j+=i){
                int w=1;
                for(int k=0;k<mid;k++,w=1ll*w*wn%P){
                    int x=a[j+k],y=1ll*w*a[j+mid+k]%P;
                    a[j+k]=(0ll+x+y+P)%P,a[j+mid+k]=(0ll+x-y+P)%P;
                }
            }
        }
    }
    void solve(int opt){
        ntt(a,1),ntt(b,1);
        for(int i=0;i<lim;i++)a[i]=1ll*a[i]*b[i]%P;
        ntt(a,0);
        for(int i=0,s=inv(lim);i<n+m;i++)p[i]+=(1ll*a[i]*s%P)*opt%P,p[i]=(p[i]+P)%P;
        for(int i=0;i<lim;i++)a[i]=b[i]=0;
    }
    int init(int l1,int r1,int l2,int r2){
        m=r1-l1+1,n=r2-l2+1;
        for(int i=0;i<m;i++)f[i]=ss[l1+i];
        for(int i=0;i<n;i++)g[i]=tt[l2+i];
        for(int i=0;i<m/2;i++)swap(f[i],f[m-i-1]);
        int k=0;lim=1;
        while(lim<=n+m)lim<<=1,k++;
        for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
        for(int i=0;i<m;i++)a[i]=f[i]*f[i]*f[i];
        for(int i=0;i<n;i++)b[i]=g[i];
        solve(1);
        for(int i=0;i<m;i++)a[i]=f[i]*f[i];
        for(int i=0;i<n;i++)b[i]=g[i]*g[i];
        solve(-2);
        for(int i=0;i<m;i++)a[i]=f[i];
        for(int i=0;i<n;i++)b[i]=g[i]*g[i]*g[i];
        solve(1);
        for(int i=0;i<lim;i++)f[i]=g[i]=rev[i]=0;
        int res=-1;
        for(int x=m-1;x<=n-1;x++)if(p[x]==0){
            res=x-m+1;
            break;
        }
        for(int i=0;i<lim;i++)p[i]=0;
        return res;
    }
}NTT;
int work(int l,int r){
    int len=r-l+1,res;
    if(tot-top+1<len)return 0;
    if(tot-top+1<=len*2)res=NTT.init(l,r,top,tot);
    else res=NTT.init(l,r,top,top+len*2-1);
    if(res==-1){
        top+=len;
        return -1;
    }
    top+=res+len;
    return 1;
}
int n,m;
signed main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >>n>>m>>s>>t;
    int p=0,q=0;
    for(int i=0;i<n;i++)p+=(s[i]=='*');
    for(int i=0;i<m;i++)q+=(t[i]=='*');
    if(p&&q){
        while(s[n-1]!='*'&&t[m-1]!='*'){
            if(s[n-1]!='-'&&t[m-1]!='-'&&s[n-1]!=t[m-1]){
                cout <<"No";
                return 0;
            }
            n--,m--;
        }
        for(int i=0;s[i]!='*'&&t[i]!='*';i++){
            if(s[i]!='-'&&t[i]!='-'&&s[i]!=t[i]){
                cout <<"No";
                return 0;
            }
        }
        cout <<"Yes";
        return 0;
    }
    if((!p)&&(!q)){
        if(n!=m){
            cout <<"No";
            return 0;
        }
        for(int i=0;i<n;i++){
            if(s[i]!='-'&&t[i]!='-'&&s[i]!=t[i]){
                cout <<"No";
                return 0;
            }
        }
        cout <<"Yes";
        return 0;
    }
    if(!p)swap(n,m),swap(s,t);
    while(s[n-1]!='*'&&t[m-1]!='*'&&n&&m){
        if(s[n-1]!='-'&&t[m-1]!='-'&&s[n-1]!=t[m-1]){
            cout <<"No";
            return 0;
        }
        n--,m--;
    }
    while(s[0]!='*'&&t[0]!='*'&&n&&m){
        if(s[0]!='-'&&t[0]!='-'&&s[0]!=t[0]){
            cout <<"No";
            return 0;
        }
        s.erase(0,1),t.erase(0,1);
        n--,m--;
    }
    if(m==0){
        for(int i=0;i<n;i++)if(s[i]!='*'){
            cout <<"No";
            return 0;
        }
        cout <<"Yes";
        return 0;
    }
    s.erase(n),t.erase(m);
    while(s[0]=='*')s.erase(0,1),n--;
    for(int i=0;i<n;i++)ss[i]=(s[i]=='-'?0:s[i]-'a'+1);
    for(int i=0;i<m;i++)tt[i]=(t[i]=='-'?0:t[i]-'a'+1);
    tot=m-1;
    for(int i=0,st=0;i<n;i++){
        if(s[i]=='*'){
            if(s[i+1]!='*')st=i+1;
            continue;
        }
        if(s[i+1]!='*')continue;
        while(true){
            int res=work(st,i);
            if(res==0){
                cout <<"No";
                return 0;
            }else if(res==1)break;
        }
    }
    cout <<"Yes";
    return 0;
}

关于位置对称的字符串匹配问题

P4199 万径人踪灭

f_i=\sum\limits_{j=0}^i [s_j=s_{2\times i-j}]

如果不管不能是连续的一段的限制,那么每一个 i 的答案就是 2_{f_i}-1

是连续的一段的限制直接用 Manacher 做(其实也可以二分+哈希)。

发现 f_i=\sum\limits_{j=0}^i [s_j=s_{2\times i-j}] 是卷积形式。

a_i 表示 s_i 是否为 ab_i 表示 s_i 是否为 b

那么 f=a*a+b*b

FFT 直接做即可,时间复杂度 O(n\log n)

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef complex<double> cp;
const int N=4e5+5;
const int mod=1e9+7;
int n,cnt,dp[N],len=0,f[N],g[N],ans=0;
cp a[N],b[N];
char ch[N],s[N];
int qpow(int a,int b){
    int sum=1;
    while(b){
        if(b&1)sum=sum*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return sum;
}
struct node{
    const double pi=acos(-1);
    int n,m,lim=1,rev[N];
    void fft(cp *a,int flag){
        for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
        for(int mid=1;mid<lim;mid<<=1){
            cp wn(cos(pi/mid),flag*sin(pi/mid));
            for(int i=mid*2,j=0;j<lim;j+=i){
                cp w(1,0);
                for(int k=0;k<mid;k++,w*=wn){
                    cp x=a[j+k],y=w*a[j+mid+k];
                    a[j+k]=x+y,a[j+mid+k]=x-y;
                }
            }
        }
    }
    void solve(int n,int m,int *p){
        int k=0;lim=1;
        while(lim<=n+m)lim*=2,k++;
        for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
        fft(a,1),fft(b,1);
        for(int i=0;i<lim;i++)a[i]*=b[i];
        fft(a,-1);
        for(int i=0;i<=n+m;i++)p[i]=(int)(a[i].real()/lim+0.5);
        memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
    }
}FFT;
signed main(){
    cin >>ch;
    int n=strlen(ch);
    s[++len]='<';s[++len]='#';
    for(int i=0;i<n;i++){
        s[++len]=ch[i];
        s[++len]='#';
    }
    s[++len]='>';
    for(int i=1,mid=0,r=0;i<len-1;i++){
        if(i<r)dp[i]=min(dp[2*mid-i],dp[mid]+mid-i);
        else dp[i]=1;
        while(s[i+dp[i]]==s[i-dp[i]])dp[i]++;
        if(i+dp[i]>r)r=i+dp[i],mid=i;
        ans=(ans-(dp[i]>>1)+mod)%mod;
    }
    for(int i=0;i<n;i++)a[i]=b[i]=(ch[i]=='a');
    FFT.solve(n,n,f);
    for(int i=0;i<n;i++)a[i]=b[i]=(ch[i]=='b');
    FFT.solve(n,n,g);
    for(int i=0;i<n+n;i++){
        f[i]=(f[i]+g[i]+((i&1)?0:1))>>1,ans=(ans+qpow(2,f[i])-1+mod)%mod;
    }
    cout <<ans;
    return 0;
}