快速傅里叶变换

· · 题解

快速傅里叶变换

update:2020.7.29 增加了三次变两次优化

【预告】

在解决多项式卷积(乘法)的问题时,直接相乘的O(n^2)做法似乎找不到优化的余地。

这时,我们就需要寻找其它的路径来计算卷积。

众所周知,二维坐标系上横坐标互异的n个点可以确定一个n-1次函数。

与用各项系数表示多项式的系数表示法不同,用点表示一个多项式的方法叫点值表示法。

离散傅里叶变换(DFT,Discrete Fourier Transformation)指的是多项式由系数表示法转为点值表示法的过程。

相对地,把一个多项式的点值表示法转化为系数表示法的过程,就是离散逆傅里叶变换(IDFT)。

快速傅里叶变换(Fast Fourier Transformation)指的是在执行DFTIDFT时通过代入特殊的值加速过程的一种算法。

【前置知识】

  1. 三角函数(不讲)

  2. 复数运算(不讲)

  3. 复平面及单位根(不想讲,百度有)

【数学推导部分】

首先,我们来看看单位根的性质

再看看这个多项式

为了方便讨论,以下提及的多项式的长度n均为2的整数次幂,即多项式次数为n-1,不够就往高位补0

我们在DFT中,其实要计算的的是

观察每一个n-1次的多项式:

按次数奇偶性分类,设

可以得到

不妨把单位根 \omega_n^k (k<\frac{n}{2}) 代进去看看

再代一个 \omega_{n}^{k+\frac{n}{2}}

还没发现问题?我把要求的式子再列出来

计算一半就能知道另一半的结果

原来O(n^2)的时间就会变成O(\frac{n^2}{2})

如果用递归来完成每一步这样的操作,将G(x)H(x)的系数也这样计算,由于递归树中的每一层分类、计算的总量复杂度是O(n),递归树最多logn层,所以这样就可以实现O(nlogn)计算这n个值了

递归式可以写出来。经过若干次分组之后,当前计算的多项式的长度变为2^t,设当前要计算的多项式为G_t(x)(也有可能是H_t(x))

每层递归中只需要写成下面这样

G_t(\omega_n^k)=G_{t-1}(\omega_\frac{n}{2}^k)+\omega_n^kH_{t-1}(\omega_\frac{n}{2}^k) G_t(\omega_n^{k+\frac{n}{2}})=G_{t-1}(\omega_\frac{n}{2}^k)-\omega_n^kH_{t-1}(\omega_\frac{n}{2}^k)

其中G_0(k)=a_k\omega_n^kk为图中编号而不是下标

如下图,以n=8为例,每一层进行分组,按上面的结论计算,蓝色代表相加,红色代表相减

不废话了,先来看看代码

顺便给出一个复数数据类型模板

const double pi=acos(-1.0);
struct Cplx{
    double Re,Im;
    Cplx(double p=0,double q=0):Re(p),Im(q){}
    Cplx operator + (const Cplx &a) const {return Cplx(Re+a.Re,Im+a.Im);}
    Cplx operator - (const Cplx &a) const {return Cplx(Re-a.Re,Im-a.Im);}
    Cplx operator * (const Cplx &a) const {return Cplx(Re*a.Re-Im*a.Im,Re*a.Im+Im*a.Re);}
}A[MAXN],B[MAXN],C[MAXN];
void FFT(Cplx *a,int n){
    if(n==1) return;
    int m=n>>1;Cplx a1[m],a2[m];
    for(int i=0;i<m;i++){//分类
        a1[i]=a[i<<1];
        a2[i]=a[i<<1|1];
    }
    FFT(a1,m);FFT(a2,m);//递归
    Cplx w1=Cplx(cos(2.0*pi/n),sin(2.0*pi/n)),wk=Cplx(1,0);
    for(int i=0;i<m;i++,wk=wk*w1){//计算
        a[i]=a1[i]+wk*a2[i];
        a[i+m]=a1[i]-wk*a2[i];
    }
    return;
}

讲了这么久,终于讲完了DFT,接下来就是IDFT

经过一波神奇的操作,IDFT的代码可以变得和DFT的代码相似度高达84.946\%

以下是公式推导时间

刚刚我们利用系数a_0,a_1,...,a_{n-1}计算出了F(\omega_n^0),F(\omega_n^1),...,F(\omega_n^{n-1})

相当于计算了序列\{b_k\},其中b_k=\sum\limits_{i=0}^{n-1}a_i(\omega_n^k)^i

再尝试一下对b_0,b_1,...,b_{n-1},将\omega_n^k换成\omega_n^{-k}再试一下

设一条序列\{c_k\},其中c_k=\sum\limits_{j=0}^{n-1}b_j(\omega_n^{-k})^j

那么

c_k=\sum\limits_{j=0}^{n-1}(\sum\limits_{i=0}^{n-1}a_i(\omega_n^j)^i)(\omega_n^{-k})^j =\sum\limits_{j=0}^{n-1}(\sum\limits_{i=0}^{n-1}a_i(\omega_n^i)^j)(\omega_n^{-k})^j =\sum\limits_{j=0}^{n-1}\sum\limits_{i=0}^{n-1}a_i(\omega_n^i)^j(\omega_n^{-k})^j =\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{n-1}a_i(\omega_n^i)^j(\omega_n^{-k})^j =\sum\limits_{i=0}^{n-1}(a_i\sum\limits_{j=0}^{n-1}(\omega_n^{i-k})^j)

我们先看看这条式子S(x)=\sum\limits_{j=0}^{n-1}(\omega_n^{i-k})^j

i=k时显然有S(x)=n

i\not=k根据单位根的几何性质显然有S(x)=0

所以c_k=\sum\limits_{i=0}^{n-1}[i=k]na_i=na_k

a_k=\frac{c_k}{n}

设对于序列的运算FFT(\{a_k\})=\{\sum\limits_{k=0}^{n-1}a_k\omega_n^k\}

FFT^{-1}(\{a_k\})=\{\sum\limits_{k=0}^{n-1}a_k\omega_n^{-k}\}

FFT^{-1}(FFT(\{a_k\}))=\{na_k\}

至此,DFT和IDFT我们都找到了办法来实现

Quod$ $Erat$ $Demonstrandum

【递归实现代码】

先上一个AC不了模板题的递归代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;

const double pi=acos(-1.0);
const int MAXN=5000005;
struct Cplx{
    double Re,Im;
    Cplx(double p=0,double q=0):Re(p),Im(q){}
    Cplx operator + (const Cplx &a) const {return Cplx(Re+a.Re,Im+a.Im);}
    Cplx operator - (const Cplx &a) const {return Cplx(Re-a.Re,Im-a.Im);}
    Cplx operator * (const Cplx &a) const {return Cplx(Re*a.Re-Im*a.Im,Re*a.Im+Im*a.Re);}
}A[MAXN],B[MAXN],C[MAXN];
int N,M,L;
void FFT(Cplx *a,int n,int d){
    if(n==1) return;
    int m=n>>1;Cplx a1[m],a2[m];
    for(int i=0;i<m;i++){
        a1[i]=a[i<<1];
        a2[i]=a[i<<1|1];
    }
    FFT(a1,m,d);FFT(a2,m,d);
    Cplx w1=Cplx(cos(2.0*pi/n),d*sin(2.0*pi/n)),wk=Cplx(1,0);
    for(int i=0;i<m;i++,wk=wk*w1){
        a[i]=a1[i]+wk*a2[i];
        a[i+m]=a1[i]-wk*a2[i];
    }
    return;
}
void Input(){
    scanf("%d%d",&N,&M);
    for(int i=0;i<=N;i++) scanf("%lf",&A[i].Re);
    for(int i=0;i<=M;i++) scanf("%lf",&B[i].Re);
    for(L=1;L<N+M+1;L<<=1);
    return;
}
int main(){
    Input();
    FFT(A,L,1);
    FFT(B,L,1);
    for(int i=0;i<=L;i++)
        C[i]=A[i]*B[i];
    FFT(C,L,-1);
    for(int i=0;i<=N+M;i++)
        printf("%d ",(int)(C[i].Re/L+0.5));
    return 0;
}

【迭代优化】

复杂度没错,是O(nlogn),为甚会TLE?

我们需要一点小小的优化

再回来看看这个图

我们发现

  1. 可以先把多项式的系数排成图中最下面一层的顺序,再自底向上运算

  2. 计算时只会用到向下一层的数,完全可以只用同一个数组覆盖着计算

关于第一点,我们发现有一种神奇的方法可以O(n)计算出最底层的顺序

还是以n=8为例

最底层是\{rnk_i\}=\{0,4,2,6,1,5,3,7\}

恰好是\{0,1,2,3,4,5,6,7\}在二进制下的反转

像这样

(000)_2,(001)_2,(010)_2,(011)_2,(100)_2,(101)_2,(110)_2,(111)_2 (000)_2,(100)_2,(010)_2,(110)_2,(001)_2,(101)_2,(011)_2,(111)_2

根据这个原理,我们可以计算出\{rnk_i\}数组

先看一条式子,也许你就会O(n)计算\{rnk_i\}

rnk_i=rnk_\frac{n}{2}*2+[i \mod 2]

没错!这样算就是rnk_i=i

只要反过来

rnk_i=rnk_\frac{n}{2}/2+(n/2)*[i \mod 2=1]

就行了。不信自己推

并且在将系数排成\{rnk_i\}时这样写

for(int i=0;i<=n-1;i++)
    if(rnk[i]>i) swap(a[i],a[rnk[i]]);

就可以了

只要我们第一层循环枚举当前要合并的区间长度n

第二层枚举完每一个要合并的区间,其中i是当前处理的区间的左端点

再像递归的做法那样在[i,i+\frac{n}{2}]里枚举j

t1=a[j],t2=a[j+\frac{n}{2}]*\omega_n^k

a[j]=t1+t2,a[j+\frac{n}{2}]=t1-t2;

详情看代码

for(int n=2;n<=L;n<<=1){
    Cplx w1(cos(2.0*pi/n),d*sin(2.0*pi/n));
    for(int i=0;i<L;i+=n){
        Cplx wk(1,0);
        for(int j=i;j<i+(n>>1);j++,wk=wk*w1){
            Cplx t1=a[j],t2=a[j+(n>>1)]*wk;
            a[j]=t1+t2;
            a[j+(n>>1)]=t1-t2;
        }
    }
}

FFT特色:三次变两次优化

将B放到A的虚部上,得到

a+bi

用FFT计算它的平方,得

(a+bi)^2=(a^2-b^2)+(2ab)i

结果的虚部/2就是答案了。别忘了除以L。

这样可以少做一次FFT,常数小。

【模板代码】Luogu P3803

注意一下细节就能AC了

//Fast Fourier Transformation
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;

const int MAXN=4000005;
const double pi=acos(-1.0);
struct Cplx{
    double Re,Im;
    Cplx(double re=0,double im=0):Re(re),Im(im){}
    Cplx operator + (const Cplx &a) const {return Cplx(Re+a.Re,Im+a.Im);}
    Cplx operator - (const Cplx &a) const {return Cplx(Re-a.Re,Im-a.Im);}
    Cplx operator * (const Cplx &a) const {return Cplx(Re*a.Re-Im*a.Im,Re*a.Im+Im*a.Re);}
}A[MAXN],B[MAXN],C[MAXN];
int N,M,L,rnk[MAXN];

void Input(){
    scanf("%d%d",&N,&M);
    for(int i=0;i<=N;i++) scanf("%lf",&A[i].Re);
    for(int j=0;j<=M;j++) scanf("%lf",&B[j].Re);
    for(L=1;L<N+M+1;L<<=1);
    for(int i=1;i<L;i++){
        rnk[i]=(rnk[i>>1]>>1);
        if(i&1) rnk[i]|=(L>>1);
    }
    return;
}

void FFT(Cplx *a,int d){
    for(int i=0;i<L;i++)
        if(i<rnk[i]) swap(a[i],a[rnk[i]]);
    for(int n=2;n<=L;n<<=1){
        Cplx w1(cos(2.0*pi/n),d*sin(2.0*pi/n));
        for(int i=0;i<L;i+=n){
            Cplx wk(1,0);
            for(int j=i;j<i+(n>>1);j++,wk=wk*w1){
                Cplx t1=a[j],t2=a[j+(n>>1)]*wk;
                a[j]=t1+t2;
                a[j+(n>>1)]=t1-t2;
            }
        }
    }
    return;
}

void mul1(){
    FFT(A,1);FFT(B,1);
    for(int i=0;i<L;i++)
        C[i]=A[i]*B[i];
    FFT(C,-1);
    for(int i=0;i<=N+M;i++)
        printf("%d ",(int)(C[i].Re/L+0.5));
    return;
}

//(a+bi)^2=(a^2-b^2)+(2ab)i
void mul2(){
    for(int i=0;i<L;i++)
        A[i].Im=B[i].Re;
    FFT(A,1);
    for(int i=0;i<L;i++)
        C[i]=A[i]*A[i];
    FFT(C,-1);
    for(int i=0;i<=N+M;i++)
        printf("%d ",(int)(C[i].Im/L/2.0+0.5));
    return;
}

int main(){
    Input();
    mul2();
    return 0;
}