题解:P4721 【模板】分治 FFT

· · 题解

介绍 cdq+FFT 做法,跑的比较慢。

我们发现题目的式子其实是关于 f 的一个递推式。即知道 f_0\sim f_{i-1} 可以推出 f_i。考虑 cdq 分治。

假设我们已经求出 f_l\sim f_{mid},考虑这些项贡献到 f_{mid+1}\sim f_r。我们知道有 f_{i+j}=f_i\cdot g_j。其中 i\in [l,mid],j\in [1,r-l]。我们用 tmp1_0\sim tmp1_{mid-l} 代替 f_l\sim f_{mid},用 tmp2_1\sim tmp2_{r-l} 代替 g_1\sim g_{r-l}

考虑卷积后每一位应该所在的下标。f_{mid+1}\leftarrow f_{mid}\cdot g_1,而在新数组中 tmp_{mid-l+1}\leftarrow tmp1_{mid-l}\cdot tmp2_1。因此有 f_{i}\leftarrow tmp_{i-l}

想必写这题板子都写过了吧,就不说 NTT 了。

code:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=998244353;
int n1,n2,n,k,rev[300000];
ll a[300000],b[300000],tmp[300000],tmp2[300000];
ll inv1[300005],inv2[300005];
ll power(ll x,ll k){
    if(k==0) return 1;
    if(k==1) return x;
    if(k%2){
        ll y=power(x,k/2);
        return y*y%mod*x%mod;
    }
    else{
        ll y=power(x,k/2);
        return y*y%mod;
    }
}
void init(){
    for(int i=0;i<n;i++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<k-1);
    }
}
void NTT(ll *a,int n,int op){
    for(int i=0;i<n;i++){
        if(i<rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int s=1;s<n;s<<=1){
        ll wn=(op==1? power(3,(mod-1)/(s<<1)):power(power(3,mod-2),(mod-1)/(s<<1)));
        for(int i=0;i<n;i+=s<<1){
            ll w=1;
            for(int j=i;j<i+s;j++,w=w*wn%mod){
                ll x=a[j],y=w*a[j+s]%mod;
                a[j]=(x+y)%mod;
                a[j+s]=(x-y+mod)%mod;
            }
        }
    }
}
void cdq(int l,int r){
    if(l==r) return;
    int mid=l+r>>1;
    cdq(l,mid);
    n=2,k=1;
    while(n<mid-l+1+r-l){
        n<<=1;
        k++;
    }
    init();
    for(int i=0;i<n;i++){
        tmp[i]=0;//注意清空
        tmp2[i]=b[i];
    }
    for(int i=l;i<=mid;i++){
        tmp[i-l]=a[i];
    }
    NTT(tmp,n,1);
    NTT(tmp2,n,1);
    for(int i=0;i<n;i++){
        tmp[i]=tmp[i]*tmp2[i]%mod;
    }
    NTT(tmp,n,-1);
    ll invn=power(n,mod-2);
    for(int i=mid+1;i<=r;i++){
        a[i]=(a[i]+tmp[i-l]*invn%mod)%mod;
    }
    cdq(mid+1,r);
}
int main(){
    cin>>n1;
    for(int i=1;i<n1;i++){
        scanf("%lld",&b[i]);
    }
    a[0]=1;
    cdq(0,n1-1);
    for(int i=0;i<n1;i++){
        printf("%d ",a[i]);
    }
    return 0;
}