P3803 【模板】多项式乘法(FFT)

· · 题解

前置知识

为方便理解,我先讲 FFT 的作用,然后再解释原理。

作用

知周所众,一个一元 n 次多项式有两种表示方式。第一种是常见的系数表示法,如下:

a_0+a_1x+a_2x^2+\cdots+a_nx^n

还有一种是不那么常见的点值表示法,如下:

(x_0,y_0),(x_1,y_1),(x_2,y_2),\dots,(x_n,y_n)

第二种表示方法常见的例子有:用平面直角坐标系上两点确定一条直线,或者三点确定一条抛物线。它的本质就是取 n+1 个不同的值带入多项式,得出结果,以此确定一个多项式。

FFT 的作用就是在 O(n\log n) 的时间复杂度内将一个一元 n 次多项式在两种表示法中转换。

那么如果是点值表示法的两个多项式相乘呢?首先要保证 x_i 是对应相等的,然后 y_i 相乘即可,时间复杂度 O(n)。是不是快了很多?

然而有个残酷的事情是,如果朴素的将系数表示法转换为点值表示法,即一一带入数字并计算结果,那么时间复杂度依然是 O(n^2),起不到任何优化。

这时,FFT 的作用就体现出来了。它可以优化这一部分复杂度至 O(n\log n),大大加快了效率。与它作用相似的还有 NTT,感兴趣的读者可以自行上网查询资料。

原理

FFT 通过将复数单位根的整数次幂带入多项式,分治快速求解。至于为什么选择它,自然是因为单位根具有一些美妙的性质。

单位根

数学上,n 次单位根是 n 次幂为1的复数。它们位于复平面的单位圆上,构成正 n 边形的顶点,其中一个顶点是 1

摘自百度百科。

什么是单位圆呢?

就是圆心为原点,半径为单位长度的圆,如图所示:

而如果我们将单位圆 n 等分,且其中一个等分点在 (1,0) 处,那么从实轴开始逆时针遇到的第二个点表示的复数即为单位根,设为 \omega_n 。如图为 n=8 时的情况:

容易发现这些点按逆时针可依次表示为 \omega_8^0\omega_8^1\omega_8^2\cdots\omega_8^7,而 \omega_8^8=\omega_8^0=1。这符合上文提到的单位根的定义。

还有两个十分重要的发现是当 n2 的整数次幂时,\omega_n^i=-\omega_n^{i+\frac{n}{2}},以及当 n 为偶数时,\omega_n^{2i}=\omega_{\frac{n}{2}}^i,至于严格的证明我不会,可以自行上网查阅。

FFT 要求 n 必须为 2 的正整数次幂。即多项式最简形式的项数必须为 2 的正整数次幂。

实现

递归版

这时有巨佬就要发问了,这些性质有什么用?

可以举个例子(来自 oi-wiki)。

设一个一元 7 次多项式如下:

f(x)=a_0+a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7

将其各项按奇偶性分为两组,并在奇数组提出一个 x

f(x)=(a_0+a_2x^2+a_4x^4+a_6x^6)+(a_1x+a_3x^3+a_5x^5+a_7x^7)\\ =(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6)\\ =g(x^2)+x\cdot h(x^2)

其中 g,h 的形式与 f 类似,只是项数少了一半。

如果 x=\omega_8^k 呢?

由上文中得到的性质可得:

f(\omega_8^k)=g((\omega_8^k)^2)+\omega_8^k\cdot h((\omega_8^k)^2)\\ =g(\omega_8^{2k})+\omega_8^k\cdot h(\omega_8^{2k})\\ =g(\omega_4^k)+\omega_8^k\cdot h(\omega_4^k) f(\omega_8^{k+4})=g((-\omega_8^k)^2)-\omega_8^k\cdot h((-\omega_8^k)^2)\\ =g(\omega_8^{2k})-\omega_8^k\cdot h(\omega_8^{2k})\\ =g(\omega_4^k)-\omega_8^k\cdot h(\omega_4^k)

发现可以分治递归下去。

递归版代码

由此就得到了递归版 FFT。

#define cpd complex<dd>//c++自带复数STL模板
cpd tmp[N];
void FFT(cpd *f, ll n) {//递归版
//f[i]表示f(w_n^i)
    if (n == 1)return; //此时的多项式内不含未知数,无需带入
    for (int i = 0; i < n; i++) {
        if (i & 1)
            tmp[i / 2 + n / 2] = f[i];
        else
            tmp[i / 2] = f[i];
    }
    for (int i = 0; i < n; i++)
        f[i] = tmp[i];
    cpd *g = f,*h = f + n / 2; //节约空间
    FFT(g, n / 2), FFT(h, n / 2);
    cpd wn = {cos(2 * pi / n), sin(2 * pi / n)}; //单位根
    //2*pi/n为弧度制中单位根的辐角大小
    cpd w = {1, 0}; //单位根的k次幂
    for (int k = 0; k < n / 2; k++) {
        cpd u = g[k], v = w * h[k];
        f[k] = u + v;
        f[k + n / 2] = u - v;
        w *= wn;
    }
}

非递归版

递归版常数巨大,效率过低,所以就有了非递归版。

首先可以发现在递归到最底层之前没有对数据本身进行任何操作,所以可以想到提前把数据放到底层的位置,然后进行合并操作。

一次完整递归如图所示,大括号表示一组的范围:

可以发现递归前后的位置编号为其二进制下各位颠倒后的数。证明我不会

于是可以先预处理出各位最后的位置,然后合并求解。

至于如何预处理,设 rev_i 表示下标 i 转换后的位置。当枚举到 i 时,rev_{\lfloor\frac{i}{2}\rfloor} 一定被求过,此时将 rev_{\lfloor\frac{i}{2}\rfloor} 右移一位得到的数即为 rev_i 不考虑首位的结果。然后判断 i 是否为奇数决定首位即可。

原因是 \lfloor\frac{i}{2}\rfloor 就相当于去掉 i 的二进制末位。

其实递归版也可以这样优化,但是实现较为麻烦,而且效率依旧不及非递归版,本文不再赘述。

非递归版代码

然后就得到了非递归版 FFT。

#define cpd complex<dd>
void init() {
    ll pos = 1ll << (k - 1);
    for (ll i = 1; i < len; i++) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1)rev[i] += pos;
    }
}
inline void change(cpd f[]) {
    for (ll i = 0; i < len; i++) {
        if (i >= rev[i])continue;//防止转换两次又到原位
        swap(f[i], f[rev[i]]);
    }
}
inline void fft(cpd f[]) {//非递归版
//f[i]表示f(w_n^i)
    change(f);
    for (ll n = 2; n <= len; n <<= 1) {//枚举块的大小
        cpd wn = {cos(2 * pi / n), sin(2 * pi / n)};
        //2*pi/n为弧度制中单位根的辐角大小
        for (ll j = 0; j < len; j += n) {//遍历每一个块
            cpd w = {1, 0};
            for (ll i = j; i < j + n / 2; i++) {
                cpd g = f[i], h = w * f[i + n / 2];
                f[i] = g + h;
                f[i + n / 2] = g - h;
                w *= wn;
            }
        }
    }
}

IFFT

通过上面的讲解,可能有巨佬已经发现了。我们只讲了如何从系数表示法转换为点值表示法,而没有讲如何从点值表示法转换回系数表示法。而我们最终需要的是系数表示法的结果。

那么 IFFT(快速傅里叶逆变换)就派上用场了。

它其实就是将公式中的 \omega_n^k 替换为 \omega_n^{-k},然后做一遍 FFT,再将最终得到的结果除以 n 即可。

也可以理解为顺时针沿着单位圆转圈。

至于为什么这样做是对的呢?

考虑原本的多项式为:

f(x)=\sum_{i=0}^{n-1}a_ix^i

对其做两遍 FFT,第二遍相当于相当于将第一遍的结果视为另外一个多项式的系数,然后再做 FFT。

所以设:

y_i=f(w_n^i)\\ p(x)=\sum_{i=0}^{n-1}y_ix^i

\omega_n^k 的倒数 \omega_n^{-k} 带入 p 可得:

p(\omega_n^{-k})=\sum_{i=0}^{n-1}y_i\omega_n^{-ki}\\ =\sum_{i=0}^{n-1}\omega_n^{-ki}\sum_{j=0}^{n-1}a_j(\omega_n^{j})^i\\ =\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\omega_n^{ij-ik}\\ =\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i

s(\omega_n^x)=\sum_{i=0}^{n-1}(\omega_n^x)^i,当 x=0(\mod n) 显然有 s(\omega_n^x)=n

x\neq 0(\mod n) 时,可得:

s(\omega_n^x)=\sum_{i=0}^{n-1}(\omega_n^x)^i\\ \omega_n^xs(\omega_n^x)=\sum_{i=1}^{n}(\omega_n^x)^i

错位相减,可得:

(\omega_n^x-1)s(\omega_n^x)=\omega_n^n-\omega_n^0=0

即:

s(\omega_n^x)=0

所以只有 j=k 的时候,\sum_{j=0}^{n-1}(\omega_n^{j-k})^i 不为 0,即 \sum_{j=0}^{n-1}(\omega_n^{j-k})^i=n。于是得:

p(\omega_n^{-k})=\sum_{i=0}^{n-1}a_j\sum_{j=0}^{n-1}(\omega_n^{j-k})^i\\ =\sum_{i=0}^{n-1}a_jn

由此可得,只要反着做一遍 FFT,然后将结果除以 n 即可。

可以将这两个操作合二为一,只要将单位根替换为单位根的倒数(即将其纵坐标变为原来的相反数),结束时特判除以 n 即可。(代码见完整代码)

完整代码

递归版和非递归版的代码都在这。

#include<bits/stdc++.h>
using namespace std;
#define ull unsigned long long
#define ll long long
#define ld long double
#define dd double
//char buf[1<<23],*p1=buf,*p2=buf;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<23,stdin),p1==p2)?EOF:*p1++)
inline ll read() {
    ll x = 0, f = 1;
    char ch;
    while (((ch = getchar()) < 48 || ch > 57)&&ch!=EOF)if (ch == '-')f = -1;
    while (ch >= 48 && ch <= 57)x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}
char __sta[1009], __len;
inline void write(ll x,ll bo) {
    if (x < 0)putchar('-'), x = -x;
    do __sta[++__len] = x % 10 + 48, x /= 10;
    while (x);
    while (__len)putchar(__sta[__len--]);
    if(bo==3)return;
    putchar(bo ? '\n' : ' ');
}
#define cpd complex<dd>
const ll N=4e6+9;
const dd pi=acos(-1.0);
ll n,m,len;
cpd f[N],g[N];
ll rev[N];
void init(){
    n=read(),m=read();
    for(int i=0;i<=n;i++){
        f[i]={read(),0};
    }
    for(int i=0;i<=m;i++){
        g[i]={read(),0};
    }
    len=1;
    ll k=0;
    while(len<=n*2||len<=m*2)len<<=1,k++;
    ll pos=1<<(k-1);
    for(int i=1;i<len;i++){
        rev[i]=rev[i>>1]>>1;
        if(i&1)rev[i]+=pos;
    }
}
/*————————————华丽的分割线————————————*/
inline void change(cpd f[]){
    for(int i=0;i<len;i++){
        if(i>=rev[i])continue;//防止转换两次又到原位
        swap(f[i],f[rev[i]]);
    }
}
inline void fft(cpd f[],ll on){//非递归版
    change(f);
    for(int n=2;n<=len;n<<=1){//枚举块的大小
        cpd wn={cos(2*pi/n),sin(2*pi*on/n)};
        //2*pi/n为弧度制中单位根的辐角大小
        for(int j=0;j<len;j+=n){
            cpd w={1,0};
            for(int i=j;i<j+n/2;i++){
                cpd u=f[i],v=w*f[i+n/2];
                f[i]=u+v;
                f[i+n/2]=u-v;
                w*=wn;
            }
        }
    }
    if(on==-1){
        for(int i=0;i<len;i++){
            f[i]/=len;
        }
    }
}
/*————————————华丽的分割线————————————*/
cpd tmp[N];
void FFT(cpd *f, ll n, ll on) {//递归版
//f[i]表示f(w_n^i)
    if (n == 1)return; //此时的多项式内不含未知数,无需带入
    for (int i = 0; i < n; i++) {
        if (i & 1)
            tmp[i / 2 + n / 2] = f[i];
        else
            tmp[i / 2] = f[i];
    }
    for (int i = 0; i < n; i++)
        f[i] = tmp[i];
    cpd *g = f,*h = f + n / 2; //节约空间
    FFT(g, n / 2, on), FFT(h, n / 2, on);
    cpd wn = {cos(2 * pi / n), sin(on * 2 * pi / n)}; //单位根
    //2*pi/n为弧度制中单位根的辐角大小
    cpd w = {1, 0}; //单位根的k次幂
    for (int k = 0; k < n / 2; k++) {
        cpd u = g[k], v = w * h[k];
        f[k] = u + v;
        f[k + n / 2] = u - v;
        w *= wn;
    }
}
inline void fix(cpd *f) {
    for (int i = 0; i < len; i++)f[i] /= len;
}
/*————————————华丽的分割线————————————*/
void solve(){
    for(int i=0;i<len;i++){
        f[i]*=g[i];
    }
}
void finish(){
    for(int i=0;i<=n+m;i++){
        write(f[i].real()+0.5,0);
    }
}
int main(){
    init();
//  fft(f,1),fft(g,1);

    FFT(f,len,1),FFT(g,len,1);
    solve();
//  fft(f,-1);

    FFT(f,len,-1);
    fix(f);
    finish();
    return 0;
}

参考资料

oi-wiki。