谷歌翻(sheng)译(cao)机 加强版 题解

· · 题解

以下题解内容由@w6666提供,因为他懒得开博客所以就被放在我的博客里了

f[i][j]表示在A的前i个与B的前j个中有多少相同的子序列(包含空串)

考虑$A$中的区间$[i,k]$在所有分割串的方案中出现的次数。 那就需要在$A$中的子序列$P$跨过$[i,k]$这个区间,必须要有$i-1$和$k+1$这两个位置,且其他的位置在$[1,i-2]$或$[k+2,n]$之间。 枚举在$B$中与$P$匹配的子序列$Q$,设$P$的$i-1$这个位置与$Q$在$j$匹配, 那就有式子$f[i-2][j-1]\times(g[k+1][j+1]-g[k+2][j+1])$,且$A[i-1]=B[j]$。 故在$A$中总答案就是$\sum\limits^{n}_ {i=1} \sum\limits^{n}_ {k=i} \sum\limits^{m+1}_ {A[i-1]=B[j]}f[i-2][j-1]\times(g[k+1][j+1]-g[k+2][j+1])\times(k-i+1)^t

时间复杂度O(n^3)

j是固定的,构造多项式G[x]=g[x+1][j+1]-g[x+2][j+1]

V[x]=x^t

原式变为\sum\limits^{m+1}_ {j=0}\sum\limits_ {A[i-1]=B[j]} f[i-2][j-1] \sum\limits^{n}_ {k=i} G[k]\times V[k-i+1]

只要把G翻转然后用NTTGV做卷积,对于每个i取卷积后的第n-i+1项系数即可。

时间复杂度O(n^2logn)应该不卡常吧,std开O2跑500ms

(std中枚举的是区间[i+1,k],会有所不同)

#include<bits/stdc++.h>
#define N 1010
#define MOD 998244353
#define ll long long
using namespace std;
int n,m,t,ans;
int a[N<<2],v[N<<2],f[N][N],g[N][N];
//v[i]=i^t
//f[i][j]表示前(i,j)里相同子序列的方案数
//g[i][j]表示后(i,j)里相同子序列的方案数
char s1[N],s2[N];
//视s1[0]=s2[0],s1[n+1]=s2[m+1]
int rev[N<<2];
int len,bit;
//ntt用 
inline int ksm(int x,int k)
{
    int res=1;
    while(k)
    {
        if(k&1)res=1ll*res*x%MOD;
        x=1ll*x*x%MOD;
        k>>=1; 
    }
    return res;
}
inline int qm(ll x)
{
    return (x%MOD+MOD)%MOD;
}
inline void ntt(int len,int a[],int inv)
{
    for(int i=1;i<len;i++)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    for(int i=1;i<len;i<<=1)
    {
        int gi,up=i<<1;
        if(inv==1)gi=ksm(3ll,(MOD-1)/up);
        else gi=ksm(1ll*(MOD+1)/3,(MOD-1)/up);
        for(int j=0;j<len;j+=up)
        {
            int gx=1;
            for(int k=0;k<i;k++)
            {
                int x=a[j+k],y=1ll*gx*a[i+j+k]%MOD;
                a[j+k]=qm(x+y),a[i+j+k]=qm(x-y);
                gx=1ll*gx*gi%MOD;
            }
        }
    }
    if(inv==-1)
    {
        int invx=ksm(len,MOD-2);
        for(int i=0;i<len;i++)
            a[i]=1ll*a[i]*invx%MOD;
    }
}
inline void getmul(int lenx,int a[])
{
    ntt(len,a,1);
    for(int i=0;i<len;i++)
        a[i]=1ll*a[i]*v[i]%MOD;
    ntt(len,a,-1);
    for(int i=lenx;i<len;i++)
        a[i]=0;
}
int main()
{
    scanf("%d%d%d",&n,&m,&t);
    scanf("%s",s1);
    scanf("%s",s2);
    for(int i=n;i>=1;i--)s1[i]=s1[i-1];
    for(int i=m;i>=1;i--)s2[i]=s2[i-1];
    for(int i=max(n,m);i>=1;i--)
        v[i]=ksm(i,t);
    len=1,bit=-1;
    while(len<2*max(n,m)+1)len<<=1,++bit;
    for(int i=1;i<len;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<bit);
    ntt(len,v,1);//因为都是与V做卷积,直接预处理 ,当然不这样也能过
    s1[0]=s2[0]=' ';
    for(int i=0;i<=n;i++)f[i][0]=1;
    for(int i=0;i<=m;i++)f[0][i]=1;
    for(int i=1;i<=n+1;i++)g[i][m+1]=1;
    for(int i=1;i<=m+1;i++)g[n+1][i]=1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
        {
            f[i][j]=qm(f[i-1][j]+f[i][j-1]-f[i-1][j-1]);
            if(s1[i]==s2[j])
                f[i][j]=qm(f[i][j]+f[i-1][j-1]);
        }
    for(int i=n;i>=1;i--)
        for(int j=m;j>=1;j--)
        {
            g[i][j]=qm(g[i+1][j]+g[i][j+1]-g[i+1][j+1]);
            if(s1[i]==s2[j])
                g[i][j]=qm(g[i][j]+g[i+1][j+1]);
        }
    //计算A中答案 
    for(int j=1;j<=m;j++)
    {
        a[n]=0;
        for(int k=1;k<=n;k++)
            a[n-k]=qm(g[k+1][j+1]-g[k+2][j+1]);//构造多项式 
        getmul(n,a);//计算卷积 
        for(int i=1;i<=n;i++)
            if(s1[i]==s2[j])//此处枚举的是区间[i+1,k] 
                ans=qm(ans+1ll*a[n-i]*f[i-1][j-1]%MOD);
    }
    for(int i=1;i<=n;i++)a[i]=0;
    //计算B中答案 
    for(int i=1;i<=n;i++)
    {
        a[m]=0;
        for(int k=1;k<=m;k++)
            a[m-k]=qm(g[i+1][k+1]-g[i+1][k+2]);
        getmul(m,a);
        for(int j=1;j<=m;j++)
            if(s1[i]==s2[j])
                ans=qm(ans+1ll*a[m-j]*f[i-1][j-1]%MOD);
    }
    for(int i=max(n,m);i>=1;i--)
        v[i]=ksm(i,t);
    for(int k=1;k<=n;k++)
        ans=qm(ans+1ll*qm(1ll*qm(g[k+1][1]-g[k+2][1])*v[k]%MOD));//f[-1][j-1]特殊处理 
    for(int k=1;k<=m;k++)
        ans=qm(ans+1ll*qm(1ll*qm(g[1][k+1]-g[1][k+2])*v[k]%MOD));
    printf("%d",ans);
}