题解 P3803 【【模板】多项式乘法(FFT)】
「学习笔记」FFT 快速傅里叶变换
几个星期之后,继 扩展欧拉定理 之后,
虽然听得心态爆炸, 但是还好的是没有
至少我还没有坐飞机...
啥是 FFT 呀?它可以干什么?
首先,你需要知道 矩阵乘法 的相关知识。
通过 矩阵乘法 的知识,我们知道,对于一个
似乎这是比较优秀的,但是看看这道 板题 :
luoguOJ:P3803【模板】多项式乘法(FFT)
嗯? (不要管这病句...)
而 FFT 就是用来解决多项式乘法的问题的,可以把它的时间复杂度优化到
这里引用一句话:
实际上,FFT 并不是直接计算多项式乘法,而是把原来的多项式
f(x),g(x) 在O(nlog^n_2) 的复杂度内转换为它的 <u>点值表示</u>(后面会讲),而点值表示的多项式相乘的时间复杂度是O(n) 的。最后再用O(nlog_2^n) 的时间复杂度把所得多项式的点值表示转化为一般形式。
必备芝士
在学习 FFT 之前,我们需要知道 点值表示 和 复数 。
点值表示
什么是点值表示
对于一个函数
这种表达方法我们叫做 一般形式 。
如果熟悉,那么可以确定一个东西,如果我们有在这个函数上的
当然,如果知道 拉格朗日插值法 的大佬,或许可以更好地理解这句话。
也就是说一个多项式与一个点值表示是一一对应的。
那么 FFT 完成的操作就是:
- 把已知的一个多项式转化成对应的点值表示;
- 把已知的点值表示转换成对应的多项式。
复杂度都是
O(nlog_2^n) 。
点值表示的乘法
那么,假设我们用
那么我们就知道
复数
复数的定义
另外,还有一些比较明显的性质:
以上这些结论都可以通过在单位圆上画出单位根来证明。
单位复根有什么用呢?因为
傅立叶正变换
有了这些辅助知识 没错,虽然你可能已经晕了,但是他们真的只是辅助知识 ,我们终于可以进行正题了。
所谓变换,那么一定有正也有逆,现在我我们先来掌握它的正变换。
FFT 的正变换实现,是基于对多项式进行奇偶项分开递归再合并的分治进行的。
对于 n-1 次多项式,我们选择插入 n 次单位根求出其点值表达式。
记多项式
再记
再记
有
令
在已知
因此,假如我们递归求解
时间复杂度是经典的
傅里叶逆变换
刚刚研究完正的,现在我们来研究逆变换,其实也比较好理解。
观察我们刚刚的插值过程,实际上就是进行了如下的矩阵乘法。
我们记上面的系数矩阵为
对于下面定义的
考虑
当
当
根据定义,
n 次单位根的n 次方都等于1
所以:
因此将这个结果代入最上面那个公式里面,有:
“这样,逆变换 就相当于把 正变换 过程中的
w_n^k 换成w_n^{-k} ,之后结果除以n 就可以了。”——摘自某博客。
……
还是有点难理解。比如为什么我们不直接把 TM 还是一个
FFT 的代码实现
我们有两个版本——递归、迭代,相信大家都也想到了吧?
毋庸置疑的,递归版本确实很好写,将
#include<cstdio>
#include<cmath>
#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
#ifdef FILEOI
#define MAXBUFFERSIZE 500000
inline char fgetc(){
static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
}
#undef MAXBUFFERSIZE
#define cg (c=fgetc())
#else
#define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
if(f)x=-x;
}
inline int qread(){
int x=0;char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
if(x<0)return (void)(putchar('-'),fwrit(-x));
if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}
const int MAXN=3e6;
const double Pi=acos(-1.0);
class task{
private:
struct cplx{
double vr,vi;//实部和虚部
cplx(const double R=0,const double I=0):vr(R),vi(I){}//构造函数
//------------------overload----------------//
cplx operator + (const cplx a)const{return cplx(vr+a.vr,vi+a.vi);}//重载加法
cplx operator - (const cplx a)const{return cplx(vr-a.vr,vi-a.vi);}
cplx operator * (const cplx a)const{return cplx(vr*a.vr-vi*a.vi,vr*a.vi+a.vr*vi);}
cplx operator / (const double var)const{return cplx(vr/var,vi/var);}
};
int n,m;
cplx a[MAXN+5],b[MAXN+5];
void fft(cplx* f,const int len,const short opt=1){
//opt==-1 : FFT 的逆变换
if(!len)return;
cplx f0[len+5],f1[len+5];
for(int i=0;i<len;++i)
f0[i]=f[i<<1],f1[i]=f[i<<1|1];
fft(f0,len>>1,opt);
fft(f1,len>>1,opt);
cplx w=cplx(cos(Pi/len),opt*sin(Pi/len)),buf=cplx(1,0);
for(int i=0;i<len;++i,buf=buf*w){
f[i]=f0[i]+buf*f1[i];
f[i+len]=f0[i]-buf*f1[i];
}
}
public:
inline void launch(){
qread(n,m);
rep(i,0,n)scanf("%lf",&a[i].vr);
rep(i,0,m)scanf("%lf",&b[i].vr);
for(m+=n,n=1;n<=m;n<<=1);
fft(a,n>>1);
fft(b,n>>1);
rep(i,0,n-1)a[i]=a[i]*b[i];
fft(a,n>>1,-1);
rep(i,0,m)writc((int)((a[i].vr)/n+0.5),' ');
Endl;
}
}This;
signed main(){
#ifdef FILEOI
freopen("file.in","r",stdin);
freopen("file.out","w",stdout);
#endif
This.launch();
return 0;
}
似乎递归版本比较好写,现在我们来看一下迭代(递推)版本应该怎么做:
原序列:
终序列:
转换为二进制再来看看。
原序列:
终序列:
可以发现终序列是原序列每个元素的二进制翻转。
于是我们可以先把要变换的系数排在相邻位置,从下往上迭代。
这个二进制翻转过程可以自己脑补方法,只要保证时间复杂度
在这里给出一个参考的方法:
我们对于每个
这是迭代版本:
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
#ifdef FILEOI
#define MAXBUFFERSIZE 500000
inline char fgetc(){
static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
}
#undef MAXBUFFERSIZE
#define cg (c=fgetc())
#else
#define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
if(f)x=-x;
}
inline int qread(){
int x=0;char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
if(x<0)return (void)(putchar('-'),fwrit(-x));
if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}
const int MAXN=3e6;
const double Pi=acos(-1.0);
class task{
private:
struct cplx{
double vr,vi;//实部和虚部
cplx(const double R=0,const double I=0):vr(R),vi(I){}//构造函数
//------------------overload----------------//
cplx operator + (const cplx a)const{return cplx(vr+a.vr,vi+a.vi);}//重载加法
cplx operator - (const cplx a)const{return cplx(vr-a.vr,vi-a.vi);}
cplx operator * (const cplx a)const{return cplx(vr*a.vr-vi*a.vi,vr*a.vi+a.vr*vi);}
cplx operator / (const double var)const{return cplx(vr/var,vi/var);}
};
int n,m;
cplx a[MAXN+5],b[MAXN+5];
int revi[MAXN+5];
/*
f(w^x)
=f0(w^{2x})+w^x*f1(w^{2x})
=a0+a2+a4+...,a1+a3+a5+...
=a0+a4+a8+...+a2+a6+a10...+a1+a5+a9+...+a3+a7+a11...
=f00(w^{4x})+w^{2x}*f01(w^{4x})+w^x*f10(w^{4x})+w^3x*f11(w^{4x})
=a0+a4+a8+...,w^{2x}*(a2+a6+a10...),w^x*(a1+a5+a9+...),w^{3x}*(a3+a7+a11...)
f_s -> 下标将 s 反过来之后, a_i 的 i 的二进制反过来与 s 相同
.
.
.
s
000 001 010 011
|反过来
v
000 100 010 110
a0 a{k/2} a{3*k/4} a{k}
a0 a4 a2 a6
*/
/*
void fft(cplx* f,const int len,const short opt=1){
if(!len)return;
cplx f0[len+5],f1[len+5];
for(int i=0;i<len;++i)
f0[i]=f[i<<1],f1[i]=f[i<<1|1];
fft(f0,len>>1,opt);
fft(f1,len>>1,opt);
cplx w=cplx(cos(Pi/len),opt*sin(Pi/len)),buf=cplx(1,0);
for(int i=0;i<len;++i,buf=buf*w){
f[i]=f0[i]+buf*f1[i];
f[i+len]=f0[i]-buf*f1[i];
}
}
*/
inline void fft(cplx* f,const short opt=1){
for(int i=0;i<n;++i)if(i<revi[i])
swap(f[i],f[revi[i]]);
for(int p=2;p<=n;p<<=1){
//枚举层数
int len=p/2;//上一层的一半的长度
cplx tmp(cos(Pi/len),opt*sin(Pi/len));
//单位复根
for(int k=0;k<n;k+=p){
cplx buf(1,0);//记录 omega 的次方
for(int l=k;l<k+len;++l,buf=buf*tmp){
//每次 buf 累成单位 omega
cplx tt=buf*f[len+l];
f[len+l]=f[l]-tt;
f[l]=f[l]+tt;
/*
此处与递归版本的 FFT 中这一段是一样的:
f[i]=f0[i]+buf*f1[i];
f[i+len]=f0[i]-buf*f1[i];
*/
}
}
}
if(opt==-1)for(int i=0;i<n;++i)f[i]=f[i]/n;
//如果是 逆变换 , 那么需要全部 /n
}
public:
inline void launch(){
qread(n,m);
rep(i,0,n)scanf("%lf",&a[i].vr);
rep(i,0,m)scanf("%lf",&b[i].vr);
for(m+=n,n=1;n<=m;n<<=1);
rep(i,0,n-1)revi[i]=(revi[i>>1]>>1)|((i&1)?n>>1:0);//处理反转
fft(a);fft(b);
rep(i,0,n-1)a[i]=a[i]*b[i];
fft(a,-1);
rep(i,0,m)writc((int)(a[i].vr+0.5),' ');
Endl;
}
}This;
signed main(){
#ifdef FILEOI
freopen("file.in","r",stdin);
freopen("file.out","w",stdout);
#endif
This.launch();
return 0;
}
如果我可以回到过去,我一定会去当杀手。
为什么?因为我要去干翻欧某和傅某某...