题解 P4245 【【模板】任意模数NTT】
如果不取模,按 OI 常见数据范围来讲两个多项式相乘以后系数不超过
众所周知,FFT 中不用取模,只需要结果取一次模就可以了。但是这题最大的系数达到了
设
其中,
这样先算出
我们发现 DFT 的时候虚部一开始全为 0,可以通过一些方法利用上这里,减少 DFT 次数。
首先我们要明白 DFT 的本质是
如果我们知道了一个数组使得
现在我们尝试推导
其中,
同时,我们有
所以,我们发现
然后考虑 IDFT 部分,这里也可以省一次 FFT。我们正常求 IDFT 的时候最后结果全部都在实部(不明白去看单位根反演),那么如果我们在最开始每个点值都乘上
实际使用一般设
这份代码实现不够精细,没有预处理单位根,导致精度不足必须开 long double
#include<bits/stdc++.h>
using namespace std;
namespace poly
{
long double const pi=acos(-1);
struct comp
{
long double r,i;
comp(){r=i=0;}
comp(long double x,long double y){r=x,i=y;}
comp conj(){return comp(r,-i);}
friend comp operator +(comp x,comp y){return comp(x.r+y.r,x.i+y.i);}
friend comp operator -(comp x,comp y){return comp(x.r-y.r,x.i-y.i);}
friend comp operator *(comp x,comp y){return comp(x.r*y.r-x.i*y.i,x.i*y.r+x.r*y.i);}
};
typedef long long ll;
int r[400005];
comp a[400005],b[400005],c[400005],d[400005];
void fft(comp *f,int n,int op)
{
for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)+((i&1)?(n>>1):0);
for(int i=1;i<n;i++)if(i<r[i])swap(f[i],f[r[i]]);
for(int len=2;len<=n;len<<=1)
{
int q=len>>1;
comp wn=comp(cos(pi/q),op*sin(pi/q));
for(int i=0;i<n;i+=len)
{
comp w=comp(1,0);
for(int j=i;j<i+q;j++,w=w*wn)
{
comp d=f[j+q]*w;
f[j+q]=f[j]-d;
f[j]=f[j]+d;
}
}
}
}
void mtt(int *f,int *g,int *h,int n,int p)
{
for(int i=0;i<n;i++)
a[i].r=(f[i]>>15),a[i].i=(f[i]&32767),
c[i].r=(g[i]>>15),c[i].i=(g[i]&32767);
fft(a,n,1),fft(c,n,1);
for(int i=1;i<n;i++)b[i]=a[n-i].conj();
b[0]=a[0].conj();
for(int i=1;i<n;i++)d[i]=c[n-i].conj();
d[0]=c[0].conj();
for(int i=0;i<n;i++)
{
comp
aa=(a[i]+b[i])*comp(0.5,0),
bb=(a[i]-b[i])*comp(0,-0.5),
cc=(c[i]+d[i])*comp(0.5,0),
dd=(c[i]-d[i])*comp(0,-0.5);
a[i]=aa*cc+comp(0,1)*(aa*dd+bb*cc),b[i]=bb*dd;
}
fft(a,n,-1),fft(b,n,-1);
for(int i=0;i<n;i++)
{
int
aa=(ll)(a[i].r/n+0.5)%p,
bb=(ll)(a[i].i/n+0.5)%p,
cc=(ll)(b[i].r/n+0.5)%p;
h[i]=((1ll*aa*(1<<30)+1ll*bb*(1<<15)+cc)%p+p)%p;
}
}
}
using namespace poly;
int f[400005],g[400005],h[400005];
int main()
{
int n,m,p;
scanf("%d%d%d",&n,&m,&p);
for(int i=0;i<=n;i++)scanf("%d",&f[i]);
for(int i=0;i<=m;i++)scanf("%d",&g[i]);
int lim=1;
while(lim<=(n+m))lim<<=1;
mtt(f,g,h,lim,p);
for(int i=0;i<=n+m;i++)printf("%d ",h[i]);
return 0;
}