题解 P4245 【【模板】任意模数NTT】

· · 题解

如果不取模,按 OI 常见数据范围来讲两个多项式相乘以后系数不超过 10^9\times10^9\times10^5=10^{23},如果我们算出该多项式分别 \bmod 3 个 NTT 模数下的值,我们就可以通过 CRT 合并计算出模三个模数之积的值,而这样最高精度能达到 10^{26} 级别,足够应付。但是这种方法一次多项式乘法需要 9 次 DFT,比较慢。下面介绍一下比较常见的优化拆系数 FFT。

众所周知,FFT 中不用取模,只需要结果取一次模就可以了。但是这题最大的系数达到了 10^{23} 级别,直接 FFT 精度不足。对此我们引出拆系数 FFT。

bas 为一个 \sqrt{p} 级别的数,F(x),G(x) 为我们要进行乘法的多项式,那么我们可以把 两个多项式表示成以下形式:

\begin{aligned} F(x)=basA(x)+B(x)\\ G(x)=basC(x)+D(x) \end{aligned}

其中,[x^n]A(x)=\lfloor \frac{[x^n]F(x)}{bas} \rfloor[x^n]B(x)=[x^n]F(x)\bmod bas,对于 C(x),D(x) 同理。这样系数就被降低到了 10^5bas^2 的级别。那么有

\begin{aligned} F(x)G(x)&=(basA(x)+B(x))(basC(x)+D(x))\\ &=bas^2A(x)C(x)+bas(A(x)D(x)+B(x)C(x))+B(x)D(x) \end{aligned}

这样先算出 A(x),B(x),C(x),D(x) 的点值表示,然后根据乘的系数不同做三次 IDFT,这样一次多项式乘法一共要做 7 次 FFT。

我们发现 DFT 的时候虚部一开始全为 0,可以通过一些方法利用上这里,减少 DFT 次数。

首先我们要明白 DFT 的本质是 F(w_n^k)=\sum\limits_{i=0}^nf_iw_n^{ki} ,假设我们现在要算出 A(x)B(x) 的点值表示。

如果我们知道了一个数组使得 F[k]=A(w_n^k)+iB(w_n^k)G[k]=A(w_n^k)-iB(w_n^k),那么我们可以通过解方程组的方式知道 A(w_n^k)B(w_n^k)

现在我们尝试推导 FG 的关系,通过 DFT 的定义可以知道:

\begin{aligned} F[k]&=A(w_n^k)+iB(w_n^k)\\ &=\sum\limits_{j=0}^{n-1}a_iw_n^{kj}+i\sum\limits_{i=0}^nb_jw_n^{kj}\\ &=\sum\limits_{j=0}^{n-1}{n-1}(a_j+ib_j)w_n^{kj}\\ &=\sum\limits_{j=0}^{n-1}(a_j+ib_j)(\cos(X)+i\sin(X))\\ &=\sum\limits_{j=0}^{n-1}(a_j\cos(X)-b_j\sin(X))+i(b_j\cos(X)+a_j\sin(X))\\ \end{aligned}

其中,X=\frac{2\pi kj}{n}

同时,我们有

\begin{aligned} G[n-k]&=A(w_n^{n-k})-iB(w_n^{n-k})\\ &=A(w_n^{-k})-iB(w_n^{-k})\\ &=\sum\limits_{j=0}^{n-1}a_iw_n^{-kj}-i\sum\limits_{i=0}^nb_jw_n^{-kj}\\ &=\sum\limits_{j=0}^{n-1}{n-1}(a_j-ib_j)w_n^{-kj}\\ &=\sum\limits_{j=0}^{n-1}(a_j-ib_j)(\cos(X)-i\sin(X))\\ &=\sum\limits_{j=0}^{n-1}(a_j\cos(X)-b_j\sin(X))-i(b_j\cos(X)+a_j\sin(X))\\ \end{aligned}

所以,我们发现 F[k]G[n-k] 是共轭的(k=0 时要特判一下放到 G[0] 去),所以我们只需要计算出 F,就可以 O(n) 计算出 G。这样只需要 2 次 FFT,就可以计算出四个多项式的点值。

然后考虑 IDFT 部分,这里也可以省一次 FFT。我们正常求 IDFT 的时候最后结果全部都在实部(不明白去看单位根反演),那么如果我们在最开始每个点值都乘上 i,那么最后出来的结果应该都在虚部,而且数值和放在实部上相等。我们知道 FFT 是一个线性变换,所以我们可以把要求的两个多项式分别放在实部和虚部上,然后 IDFT 之后实部和虚部上就是分别的原多项式了。这样就又省了一次 FFT,一次多项式乘法总共需要 4 次。

实际使用一般设 bas2 的次幂,比较方便提取系数。
这份代码实现不够精细,没有预处理单位根,导致精度不足必须开 long double


#include<bits/stdc++.h>
using namespace std;
namespace poly
{
    long double const pi=acos(-1);
    struct comp
    {
        long double r,i;
        comp(){r=i=0;}
        comp(long double x,long double y){r=x,i=y;}
        comp conj(){return comp(r,-i);}
        friend comp operator +(comp x,comp y){return comp(x.r+y.r,x.i+y.i);}
        friend comp operator -(comp x,comp y){return comp(x.r-y.r,x.i-y.i);}
        friend comp operator *(comp x,comp y){return comp(x.r*y.r-x.i*y.i,x.i*y.r+x.r*y.i);}
    };
    typedef long long ll;
    int r[400005];
    comp a[400005],b[400005],c[400005],d[400005];
    void fft(comp *f,int n,int op)
    {
        for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)+((i&1)?(n>>1):0);
        for(int i=1;i<n;i++)if(i<r[i])swap(f[i],f[r[i]]);
        for(int len=2;len<=n;len<<=1)
        {
            int q=len>>1;
            comp wn=comp(cos(pi/q),op*sin(pi/q));
            for(int i=0;i<n;i+=len)
            {
                comp w=comp(1,0);
                for(int j=i;j<i+q;j++,w=w*wn)
                {
                    comp d=f[j+q]*w;
                    f[j+q]=f[j]-d;
                    f[j]=f[j]+d;
                }
            }
        }
    }
    void mtt(int *f,int *g,int *h,int n,int p)
    {
        for(int i=0;i<n;i++)
            a[i].r=(f[i]>>15),a[i].i=(f[i]&32767),
            c[i].r=(g[i]>>15),c[i].i=(g[i]&32767);
        fft(a,n,1),fft(c,n,1);
        for(int i=1;i<n;i++)b[i]=a[n-i].conj();
        b[0]=a[0].conj();
        for(int i=1;i<n;i++)d[i]=c[n-i].conj();
        d[0]=c[0].conj();
        for(int i=0;i<n;i++)
        {
            comp 
            aa=(a[i]+b[i])*comp(0.5,0),
            bb=(a[i]-b[i])*comp(0,-0.5),
            cc=(c[i]+d[i])*comp(0.5,0),
            dd=(c[i]-d[i])*comp(0,-0.5);
            a[i]=aa*cc+comp(0,1)*(aa*dd+bb*cc),b[i]=bb*dd;
        }
        fft(a,n,-1),fft(b,n,-1);
        for(int i=0;i<n;i++)
        {
            int 
            aa=(ll)(a[i].r/n+0.5)%p,
            bb=(ll)(a[i].i/n+0.5)%p,
            cc=(ll)(b[i].r/n+0.5)%p;
            h[i]=((1ll*aa*(1<<30)+1ll*bb*(1<<15)+cc)%p+p)%p;
        }
    }
}
using namespace poly;
int f[400005],g[400005],h[400005];
int main()
{
    int n,m,p;
    scanf("%d%d%d",&n,&m,&p);
    for(int i=0;i<=n;i++)scanf("%d",&f[i]);
    for(int i=0;i<=m;i++)scanf("%d",&g[i]);
    int lim=1;
    while(lim<=(n+m))lim<<=1;
    mtt(f,g,h,lim,p);
    for(int i=0;i<=n+m;i++)printf("%d ",h[i]);
    return 0;
}