题解 P3803 【【模板】多项式乘法(FFT)】

NaCly_Fish

2019-05-21 00:25:42

Solution

发现FFT的模板题竟然没有三次变两次优化呢qwq。。 这里就讲一下吧。 建议在看这篇题解之前,先理解基本的FFT。 **** 按照一般的做法,如果我们要求出$F(x)$和$G(x)$的卷积,会把它们分别FFT,然后对应每项乘起来,最后再逆FFT回来。 来回一共进行了三次FFT,感觉这样是很不优的。 我们可以把$G(x)$放到$F(x)$的虚部上去,求出$F(x)^2$,然后把$F(x)$的虚部取出来除$2$就是答案了qwq 正确性的证明也很简单: $$\large (a+bi)^2=(a^2-b^2)+(2abi)$$ 这样我们就把常数优化到原来的$2/3$啦! ps:跑得比NTT还快! 最后上代码: ```cpp #include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<vector> #include<cmath> #define ll long long #define N 4000009 #define reg register #define pi 3.141592653589793 using namespace std; struct complex{ double x,y; complex(double x=0,double y=0):x(x),y(y){} inline complex operator + (const complex b) const{ return complex(x+b.x,y+b.y); } inline complex operator - (const complex b) const{ return complex(x-b.x,y-b.y); } inline complex operator * (const complex b) const{ complex res; res.x = x*b.x-y*b.y; res.y = x*b.y+y*b.x; return res; } }; inline void read(int &x); void print(int x); inline void FFT(complex *a,int type,int lim); int n,m; complex F[N]; int rev[N]; signed main(){ int qwq,t,lim = 1,l = -1; read(n),read(m); for(reg int i=0;i<=n;++i) read(t),F[i].x = t; for(reg int i=0;i<=m;++i) read(t),F[i].y = t; //把G(x)放到F(x)的虚部上 t = n+m; n = max(n,m); while(lim<=(n<<1)){ lim <<= 1; ++l; } for(reg int i=1;i<=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<l); FFT(F,1,lim); for(reg int i=0;i<=lim;++i) F[i] = F[i]*F[i]; //求出F(x)^2 FFT(F,-1,lim); for(reg int i=0;i<=t;++i){ qwq = F[i].y/2+0.5; //虚部取出来除2,注意要+0.5,否则精度会有问题 print(qwq); putchar(' '); } return 0; } inline void FFT(complex *a,int type,int lim){ //没什么好说的,FFT的模板 for(reg int i=1;i<=lim;++i){ if(i>=rev[i]) continue; swap(a[i],a[rev[i]]); } reg complex rt,w,x,y; for(reg int mid=1;mid<lim;mid<<=1){ reg int r = mid<<1; rt = complex(cos(pi/mid),type*sin(pi/mid)); for(reg int j=0;j<lim;j+=r){ w = complex(1,0); for(reg int k=0;k<mid;++k){ x = a[j|k]; y = w*a[j|k|mid]; a[j|k] = x+y; a[j|k|mid] = x-y; w = w*rt; } } } if(type==1) return; for(reg int i=0;i<=lim;++i){ a[i].y = a[i].y/lim; a[i].x = a[i].x/lim; } } inline void read(int &x){ x = 0; char c = getchar(); while(c<'0'||c>'9') c = getchar(); while(c>='0'&&c<='9'){ x = (x<<3)+(x<<1)+(c^48); c = getchar(); } } void print(int x){ if(x>9) print(x/10); putchar(x%10+'0'); } ```