题解 P5394 【【模板】下降幂多项式乘法】

· · 题解

我们已知 n 次多项式 f(x)[0,n] 的点值,求出它的下降幂表示。

f(x)=\sum\limits_{i=0}^{n}b_ix^{\underline{i}}=\sum\limits_{i=0}^{n}b_i\dfrac{x!}{(x-i)!}

\dfrac{f(x)}{x!}=\sum\limits_{i=0}^{n}b_i\dfrac{1}{(x-i)!}=b*e^x

因此,b=\text{EGF}(f)*e^{-x}

同理,下降幂表示转成点值也一样简单,卷上 e^x 即可。

因此,对于本题而言,我们先将下降幂多项式转为点值,然后直接相乘(注意转化的点值带有 EGF 的系数,乘的时候要消去),然后转回来即可。

注意,因为卷积会导致次数增大,e^xe^{-x} 的次数需要开到 2\max(n,m)

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353,g=3,invg=(mod+1)/g; 
int jc[200005]={1},ny[200005]={1},n,m,N,a[600005],b[600005],c[600005],d[600005],e[600005],tr[600005];
int C(int x,int y){
    if(x<y)return 0;
    return 1ll*jc[x]*ny[y]%mod*ny[x-y]%mod;
}
int Power(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=1ll*ret*x%mod;
        y>>=1,x=1ll*x*x%mod;
    }
    return ret;
}
void NTT(int a[],int len,int flag){
    for(int i=0;i<len;i++)if(tr[i]<i)swap(a[i],a[tr[i]]);
    for(int i=1;i<len;i<<=1){
        int w=Power(flag==1?g:invg,(mod-1)/(i<<1));
        for(int j=0;j<len;j+=(i<<1)){
            int u=1;
            for(int k=0;k<i;k++,u=1ll*u*w%mod){
                int x=a[j+k],y=a[i+j+k];
                a[j+k]=(x+1ll*u*y)%mod,a[i+j+k]=(x-1ll*u*y%mod+mod)%mod;
            }
        }
    }
    if(flag==-1)for(int i=0,inv=Power(len,mod-2);i<len;i++)a[i]=1ll*a[i]*inv%mod; 
}
int main() {
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%d",&c[i]);
    for(int i=0;i<=m;i++)scanf("%d",&d[i]);
    N=max(n,m)*2;
    for(int i=1;i<=N;i++)jc[i]=1ll*jc[i-1]*i%mod,ny[i]=Power(jc[i],mod-2);
    for(int i=0;i<=N;i++){
        if(i&1)b[i]=mod-ny[i];
        else b[i]=ny[i];
        a[i]=ny[i];
    }
    int len=1;
    while(len<=2*N)len<<=1;
    for(int i=0;i<len;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?(len>>1):0);
    NTT(a,len,1),NTT(b,len,1),NTT(c,len,1),NTT(d,len,1);
    for(int i=0;i<len;i++)c[i]=1ll*a[i]*c[i]%mod,d[i]=1ll*d[i]*a[i]%mod;
    NTT(c,len,-1),NTT(d,len,-1);
    for(int i=0;i<=N;i++)e[i]=1ll*c[i]*d[i]%mod*jc[i]%mod;
    NTT(e,len,1);
    for(int i=0;i<len;i++)e[i]=1ll*e[i]*b[i]%mod;
    NTT(e,len,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",e[i]);
    return 0;
}