题解 P3803 【【模板】多项式乘法(FFT)】
更好的阅读体验请点这里
Intro:
本篇博客将会从朴素乘法讲起,经过分治乘法,到达FFT和NTT
旨在能够让读者(也让自己)充分理解其思想
朴素乘法
约定:两个多项式为A(x)=\sum_{i=0}^{n}a_ix^i,B(x)=\sum_{i=0}^{m}b_ix^i
Prerequisite knowledge:
初中数学知识(手动滑稽)
最简单的多项式方法就是逐项相乘再合并同类项,写成公式:
若
于是一个朴素乘法就产生了,见代码(利用某种丧心病狂的方式省了
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int x;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,x=c&15;else k=x=0;
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define N (2000010)
int n,m,a[N],b,c[N];
signed main(){
Rd(n),Rd(m);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m){Rd(b);Frn1(j,0,n)c[i+j]+=b*a[j];}
Frn1(i,0,n+m)wr(c[i]),Ps;
exit(0);
}
Time complexity:
Memory complexity:
看看效果
意料之中,所以必须优化
朴素分治乘法
P.s 这一部分讲述了FFT的分治方法,与FFT还是有区别的,如果已经理解的可以跳过
约定:n 为同时属于A(x),B(x) 次数界的最小的2 的正整数幂,并将两个多项式设为A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i ,不存在的系数补零
次数界:严格
Reference:
《算法导论》
Prerequisite knowledge:
分治思想
现在来考虑如何去优化乘法
尝试将两个多项式按照未知项次数的奇偶性分开:
其中
于是两个多项式就被拆成了两个次数界为
P.s 以下的公式中,用
在此可以发现一种分治算法:把两个多项式折半,然后再递归算
P.s 注意合并方式:
(为了省空间用了vector)
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int x;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,x=c&15;else k=x=0;
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s;
Vct a,b,c;
void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
void mlt(Vct&a,Vct&b,Vct&c,int n);
signed main(){
Rd(n),Rd(m),a.resize(s=1<<int(log2(max(n,m))+1)),b.resize(s),c.resize(s<<1);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m)Rd(b[i]);
mlt(a,b,c,s);
Frn1(i,0,n+m)wr(c[i]),Ps;
exit(0);
}
void mlt(Vct&a,Vct&b,Vct&c,int n){
int n2(n>>1);
Vct a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
if(n==1){c[0]=a[0]*b[0];return;}
Frn0(i,0,n2)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
mlt(a0,b0,ab0,n2),mlt(a1,b1,ab1,n2);
Frn0(i,0,n)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
mlt(a0,b1,ab0,n2),mlt(a1,b0,ab1,n2),add(ab0,ab1,abm);
Frn0(i,0,n-1)c[i<<1|1]=abm[i];
}
看看效果
好像更惨……
为什么呢,因为这个算法的时间复杂度还是
而且不仅复杂度高,常数因子也因为递归变高了
所以继续优化吧……
分治乘法
接上上一部分的内容,考虑如何优化时间复杂度
先来一个小插曲:如何只做
先看看结果:
所以如果只用
尝试把
答案出来了,用
回到原题目
于是中间项也可以使用类似的方法:
成功减少一次乘法运算,见代码
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int x;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,x=c&15;else k=x=0;
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s;
Vct a,b,c;
void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
void mns(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]-b[i];}
void mlt(Vct&a,Vct&b,Vct&c);
signed main(){
Rd(n),Rd(m),a.resize(s=1<<int(log2(max(n,m))+1)),b.resize(s),c.resize(s<<1);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m)Rd(b[i]);
mlt(a,b,c);
Frn1(i,0,n+m)wr(c[i]),Ps;
exit(0);
}
void mlt(Vct&a,Vct&b,Vct&c){
int n(a.size()),n2(a.size()>>1);
Vct a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
if(n==1){c[0]=a[0]*b[0];return;}
Frn0(i,0,n2)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
mlt(a0,b0,ab0),mlt(a1,b1,ab1);
Frn0(i,0,n)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
add(a0,a1,a0),add(b0,b1,b0),mlt(a0,b0,abm),mns(abm,ab0,abm),mns(abm,ab1,abm);
Frn0(i,0,n-1)c[i<<1|1]=abm[i];
}
看看效果
比朴素分治乘法好一点,但是还是没朴素乘法强,还是很惨
看看这个算法的时间复杂度:
运用主方法,
额那不是应该比朴素算法要好吗,这是什么情况
Reason 1. 分治乘法的常数因子太大
Reason 2. 打开
所以就要请上本文的主角了
快速傅里叶变换 FFT (Fast Fourier Transform)
Fairly Frightening Transform
约定:
Reference:
《算法导论》
自为风月马前卒:快速傅里叶变换(FFT)详解
Prerequisite knowledge:
分治思想
复数的基本知识
线性代数的基本知识
Part 1: 多项式的两种表示方式
1. 系数表达
对一个次数界为
使用系数表达时,下列操作的时间复杂度:
-
求值
O(n) -
加法
O(n) -
乘法朴素
O(n^2) ,优化(n^{\log_2 3}) (即分治乘法)
2. 点值表达
一个次数界为
进行
其中左边的矩阵表示为
使用拉格朗日公式,可以在
对于两个在相同位置求值的点值表达多项式,下列操作的时间复杂度:
-
加法
O(n) (只要将各个位置的y 值相加即可) -
乘法
O(n) (同理)
所以这就是使用FFT的原因:通过精心选取
傅里叶大神究竟选了什么神奇的
Part 2: 单位复数根及其性质
其中
可以把
如图表示的是
于是就可以得到规律:
接下来的三个引理就是FFT的重头戏啦
1. 消去引理:对任何整数n\geqslant 0,k\geqslant 0,d>0 ,有\omega_{dn}^{dk}=\omega_n^k
Proof:
2. 折半引理:对任何偶数n 和整数k ,有(\omega_n^k)^2=(\omega_n^{k+n/2})^2=\omega_{n/2}^k
Proof:
3. 求和引理:对任何整数n\geqslant 0 与非负整数k:n\nmid k ,有\sum_{j=0}^{n-1}(\omega_n^k)^j=0
Proof: 利用等比数列求和公式,
Part 3: 离散傅里叶变换 DFT (Discrete Fourier Transform)
DFT就是将次数界为
简化一下表示方法:
用公式表示就是
另外,可以发现
终于可以看看具体操作了
Part 4: FFT
FFT利用单位根的特殊性质把DFT优化到了
和分治乘法一样,按未知项次数的奇偶性分开:
其中
这时,求
1. 求A^{[0]}(x) 与A^{[1]}(x) 在(\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2 的值
根据折半引理,
所以只要对拆开的两个多项式分别做
2. 合并答案
所以
具体运行时,就每次循环结束时让一个初始为
递归边界:n=1 ,那么w_1^0 a_0=a_0 ,所以直接返回自身
计算一下时间复杂度
定理:对i,j=0,1,\cdots,n-1 ,有[V_n^{-1}]_{ij}=\omega_n^{-ij}/n
Proof: 证明
如果
否则,因为
接下来
比较一下DFT中
只要运算时把
终于可以来到激动人心的实现环节了
Part 6: 递归实现
根据前文,只要将分治乘法的代码修改一下即可
可以做到直接在原址进行FFT,就是将分开的两个多项式分置在左右两边
STL提供了现成的complex类可供使用
代码中用
P.s 最后别忘了
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int u;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,u=c&15;else k=u=0;
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx o,w,a[N],b[N],tmp[N],x,y;
int n,m,s;
bool iv;
void fft(Cpx*a,int n);
signed main(){
Rd(n),Rd(m),s=1<<int(log2(n+m)+1);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m)Rd(b[i]);
fft(a,s),fft(b,s);
Frn0(i,0,s)a[i]*=b[i];
iv=1,fft(a,s);
Frn1(i,0,n+m)wr(a[i].real()/s+0.5),Ps;
exit(0);
}
void fft(Cpx*a,int n){
if(n==1)return;
int n2(n>>1);
Frn0(i,0,n2)tmp[i]=a[i<<1],tmp[i+n2]=a[i<<1|1];
copy(tmp,tmp+n,a),fft(a,n2),fft(a+n2,n2);
o={cos(Pi/n2),(iv?-1:1)*sin(Pi/n2)},w=1;
Frn0(i,0,n2)x=a[i],y=w*a[i+n2],a[i]=x+y,a[i+n2]=x-y,w*=o;
}
Time complexity:
Memory complexity:
看看效果
性能已经超过了朴素乘法(必然的),但是还是没有AC
注意到
Part 6: 迭代实现
设l=\lceil\log_2(n+m+1)\rceil,s=2^l ,那么A(x),B(x),A(x)B(x) 都是次数界为s 的多项式
现在需要寻找到一种迭代的方式,使答案自底向上合并以减少常数因子
还是像递归版一样,把
观察每一层递归时各个系数所在位置的规律,以
0-> 0 1 2 3 4 5 6 7
1-> 0 2 4 6|1 3 5 7
2-> 0 4|2 6|1 5|3 7
end 0|4|2|6|1|5|3|7
没看出来?那就拆成二进制看看
0-> 000 001 010 011 100 101 110 111
1-> 000 010 100 110|001 011 101 111
2-> 000 100|010 110|001 101|011 111
end 000|100|010|110|001|101|011|111
显然地在最后一层递归时,系数编号正好是位置编号的反转(更准确的说是前
一个较为感性的Proof: 因为是按照奇偶性分类,也就是说在第
所以到了递归最底层,位置编号的二进制就正好是系数编号二进制前
构造数组(或看代码)
蝴蝶操作 (Butterfly Operation)
其实在递归版代码中已经出现,但是这里再详细说明一下
还记得
但是现在不使用
因为按照奇偶性分置在两边,所以
设
那么新的
这就是蝴蝶操作啦
有了蝴蝶操作,只要将所有系数按照
在代码中,用
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int u;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,u=c&15;else k=u=0;
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx a[N],b[N],o,w,x,y;
int n,m,l,s,r[N];
void fft(Cpx*a,bool iv);
signed main(){
Rd(n),Rd(m),s=1<<(l=log2(n+m)+1);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m)Rd(b[i]);
Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fft(a,0),fft(b,0);
Frn0(i,0,s)a[i]*=b[i];
fft(a,1);
Frn1(i,0,n+m)wr(a[i].real()+0.5),Ps;
exit(0);
}
void fft(Cpx*a,bool iv){
Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
for(int i(2),i2(1);i<=s;i2=i,i<<=1){
o={cos(Pi/i2),(iv?-1:1)*sin(Pi/i2)};
for(int j(0);j<s;j+=i){
w=1;
Frn0(k,0,i2){
x=a[j+k],y=w*a[j+k+i2];
a[j+k]=x+y,a[j+k+i2]=x-y,w*=o;
}
}
}
if(iv)Frn0(i,0,s)a[i]/=s;
}
Time complexity:
Memory complexity:
看看效果
终于……
到现在为止FFT的内容已经全部结束啦,下面是拓展部分
Extension: 快速数论变换 NTT (Number Theoretic Transform)
虽然FFT具有优秀的时间复杂度,但因为用到了复数,不可避免会出现精度问题
如果多项式系数和结果都是一定范围非负整数,可以考虑使用NTT来优化精度和时空常数
Reference:
《算法导论》
自为风月马前卒:快速数论变换(NTT)小结
Prerequisite knowledge:
FFT(必须知道的)
模运算基本知识
原根的性质
现在考虑所有运算都在
设有正整数
E.g 对于
1-> {1}
2-> {1,2,4}
3-> {1,2,3,4,5,6}
4-> {1,2,4}
5-> {1,2,3,4,5,6}
6-> {1,6}
所以
在代码中,一般使用大质数
原根的特点就是它的次幂以长度为
E.g
那么
这个特性和单位根非常相似
但是要完全替换单位根,还差一步
单位根的代替品
在FFT中使用的是循环长度为
所以为了让循环长度为
离散对数定理:如果g 是Z_P^* 的一个原根,则x\equiv y(\mod\phi(P))\iff g^x\equiv g^y(\mod P)
Proof: 设
因此
反过来,因为循环长度是
现在考虑有一个
也就是说对任意整数
即
即
可得
那么为了使
此时
因为
所以
那么问题来了,万一
这就引出了大质数
而根据数据范围
总结一下,
消去引理(只需考虑
最后只要把
//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
int read(){
register int u;register char c(getchar());register bool k;
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
if(c^'-')k=1,u=c&15;else k=u=0;
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
void wr(register int a){
if(a<0)Pc('-'),a=-a;
if(a<=9)Pc(a|'0');
else wr(a/10),Pc((a%10)|'0');
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define P (998244353)
#define G (3)
#define Gi (332748118)
#define N (2100000)
int n,m,l,s,r[N],a[N],b[N],o,w,x,y,siv;
int fpw(int a,int p){return p?a>>1?(p&1?a:1)*fpw(a*a%P,p>>1)%P:a:1;}
void ntt(int*a,bool iv);
signed main(){
Rd(n),Rd(m),siv=fpw(s=1<<(l=log2(n+m)+1),P-2);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m)Rd(b[i]);
Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
ntt(a,0),ntt(b,0);
Frn0(i,0,s)a[i]=a[i]*b[i]%P;
ntt(a,1);
Frn1(i,0,n+m)wr(a[i]),Ps;
exit(0);
}
void ntt(int*a,bool iv){
Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
for(int i(2),i2(1);i<=s;i2=i,i<<=1){
o=fpw(iv?Gi:G,(P-1)/i);
for(int j(0);j<s;j+=i){
w=1;
Frn0(k,0,i2){
x=a[j+k],y=w*a[j+k+i2]%P;
a[j+k]=(x+y)%P,a[j+k+i2]=(x-y+P)%P,w=w*o%P;
}
}
}
if(iv)Frn0(i,0,s)a[i]=a[i]*siv%P;
}
Time complexity:
Memory complexity:
看看效果
时间上的提升效果不大,但是空间少了一半(因为用了int而不是complex)
Conclusion:
打了一天的博客终于写完了(好累)
但是对FFT和NTT的理解也加深了不少
这个算法对数学知识和分治思想的要求都很高
本蒟蒻花了近一年的时间才真正理解
如果有错误和意见请大佬多多指教
那么本篇博客就到这里啦,谢谢各位大佬的支持!ありがとう!