P3803 【模板】多项式乘法(FFT)
前置知识
- 复数。
- 至少高一的数学基础。
FFT 简介
FFT,即快速傅里叶变换。
为方便理解,我先讲 FFT 的作用,然后再解释原理。
作用
知周所众,一个一元
还有一种是不那么常见的点值表示法,如下:
第二种表示方法常见的例子有:用平面直角坐标系上两点确定一条直线,或者三点确定一条抛物线。它的本质就是取
FFT 的作用就是在
那么如果是点值表示法的两个多项式相乘呢?首先要保证
然而有个残酷的事情是,如果朴素的将系数表示法转换为点值表示法,即一一带入数字并计算结果,那么时间复杂度依然是
这时,FFT 的作用就体现出来了。它可以优化这一部分复杂度至
原理
FFT 通过将复数单位根的整数次幂带入多项式,分治快速求解。至于为什么选择它,自然是因为单位根具有一些美妙的性质。
单位根
数学上,
n 次单位根是n 次幂为1的复数。它们位于复平面的单位圆上,构成正n 边形的顶点,其中一个顶点是1 。
摘自百度百科。
什么是单位圆呢?
就是圆心为原点,半径为单位长度的圆,如图所示:
而如果我们将单位圆
容易发现这些点按逆时针可依次表示为
还有两个十分重要的发现是当 我不会,可以自行上网查阅。
FFT 要求
实现
递归版
这时有巨佬就要发问了,这些性质有什么用?
可以举个例子(来自 oi-wiki)。
设一个一元
将其各项按奇偶性分为两组,并在奇数组提出一个
其中
如果
由上文中得到的性质可得:
发现可以分治递归下去。
递归版代码
由此就得到了递归版 FFT。
#define cpd complex<dd>//c++自带复数STL模板
cpd tmp[N];
void FFT(cpd *f, ll n) {//递归版
//f[i]表示f(w_n^i)
if (n == 1)return; //此时的多项式内不含未知数,无需带入
for (int i = 0; i < n; i++) {
if (i & 1)
tmp[i / 2 + n / 2] = f[i];
else
tmp[i / 2] = f[i];
}
for (int i = 0; i < n; i++)
f[i] = tmp[i];
cpd *g = f,*h = f + n / 2; //节约空间
FFT(g, n / 2), FFT(h, n / 2);
cpd wn = {cos(2 * pi / n), sin(2 * pi / n)}; //单位根
//2*pi/n为弧度制中单位根的辐角大小
cpd w = {1, 0}; //单位根的k次幂
for (int k = 0; k < n / 2; k++) {
cpd u = g[k], v = w * h[k];
f[k] = u + v;
f[k + n / 2] = u - v;
w *= wn;
}
}
非递归版
递归版常数巨大,效率过低,所以就有了非递归版。
首先可以发现在递归到最底层之前没有对数据本身进行任何操作,所以可以想到提前把数据放到底层的位置,然后进行合并操作。
一次完整递归如图所示,大括号表示一组的范围:
可以发现递归前后的位置编号为其二进制下各位颠倒后的数。证明我不会。
于是可以先预处理出各位最后的位置,然后合并求解。
至于如何预处理,设
原因是
其实递归版也可以这样优化,但是实现较为麻烦,而且效率依旧不及非递归版,本文不再赘述。
非递归版代码
然后就得到了非递归版 FFT。
#define cpd complex<dd>
void init() {
ll pos = 1ll << (k - 1);
for (ll i = 1; i < len; i++) {
rev[i] = rev[i >> 1] >> 1;
if (i & 1)rev[i] += pos;
}
}
inline void change(cpd f[]) {
for (ll i = 0; i < len; i++) {
if (i >= rev[i])continue;//防止转换两次又到原位
swap(f[i], f[rev[i]]);
}
}
inline void fft(cpd f[]) {//非递归版
//f[i]表示f(w_n^i)
change(f);
for (ll n = 2; n <= len; n <<= 1) {//枚举块的大小
cpd wn = {cos(2 * pi / n), sin(2 * pi / n)};
//2*pi/n为弧度制中单位根的辐角大小
for (ll j = 0; j < len; j += n) {//遍历每一个块
cpd w = {1, 0};
for (ll i = j; i < j + n / 2; i++) {
cpd g = f[i], h = w * f[i + n / 2];
f[i] = g + h;
f[i + n / 2] = g - h;
w *= wn;
}
}
}
}
IFFT
通过上面的讲解,可能有巨佬已经发现了。我们只讲了如何从系数表示法转换为点值表示法,而没有讲如何从点值表示法转换回系数表示法。而我们最终需要的是系数表示法的结果。
那么 IFFT(快速傅里叶逆变换)就派上用场了。
它其实就是将公式中的
也可以理解为顺时针沿着单位圆转圈。
至于为什么这样做是对的呢?
考虑原本的多项式为:
对其做两遍 FFT,第二遍相当于相当于将第一遍的结果视为另外一个多项式的系数,然后再做 FFT。
所以设:
将
设
当
错位相减,可得:
即:
所以只有
由此可得,只要反着做一遍 FFT,然后将结果除以
可以将这两个操作合二为一,只要将单位根替换为单位根的倒数(即将其纵坐标变为原来的相反数),结束时特判除以
完整代码
递归版和非递归版的代码都在这。
#include<bits/stdc++.h>
using namespace std;
#define ull unsigned long long
#define ll long long
#define ld long double
#define dd double
//char buf[1<<23],*p1=buf,*p2=buf;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<23,stdin),p1==p2)?EOF:*p1++)
inline ll read() {
ll x = 0, f = 1;
char ch;
while (((ch = getchar()) < 48 || ch > 57)&&ch!=EOF)if (ch == '-')f = -1;
while (ch >= 48 && ch <= 57)x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
char __sta[1009], __len;
inline void write(ll x,ll bo) {
if (x < 0)putchar('-'), x = -x;
do __sta[++__len] = x % 10 + 48, x /= 10;
while (x);
while (__len)putchar(__sta[__len--]);
if(bo==3)return;
putchar(bo ? '\n' : ' ');
}
#define cpd complex<dd>
const ll N=4e6+9;
const dd pi=acos(-1.0);
ll n,m,len;
cpd f[N],g[N];
ll rev[N];
void init(){
n=read(),m=read();
for(int i=0;i<=n;i++){
f[i]={read(),0};
}
for(int i=0;i<=m;i++){
g[i]={read(),0};
}
len=1;
ll k=0;
while(len<=n*2||len<=m*2)len<<=1,k++;
ll pos=1<<(k-1);
for(int i=1;i<len;i++){
rev[i]=rev[i>>1]>>1;
if(i&1)rev[i]+=pos;
}
}
/*————————————华丽的分割线————————————*/
inline void change(cpd f[]){
for(int i=0;i<len;i++){
if(i>=rev[i])continue;//防止转换两次又到原位
swap(f[i],f[rev[i]]);
}
}
inline void fft(cpd f[],ll on){//非递归版
change(f);
for(int n=2;n<=len;n<<=1){//枚举块的大小
cpd wn={cos(2*pi/n),sin(2*pi*on/n)};
//2*pi/n为弧度制中单位根的辐角大小
for(int j=0;j<len;j+=n){
cpd w={1,0};
for(int i=j;i<j+n/2;i++){
cpd u=f[i],v=w*f[i+n/2];
f[i]=u+v;
f[i+n/2]=u-v;
w*=wn;
}
}
}
if(on==-1){
for(int i=0;i<len;i++){
f[i]/=len;
}
}
}
/*————————————华丽的分割线————————————*/
cpd tmp[N];
void FFT(cpd *f, ll n, ll on) {//递归版
//f[i]表示f(w_n^i)
if (n == 1)return; //此时的多项式内不含未知数,无需带入
for (int i = 0; i < n; i++) {
if (i & 1)
tmp[i / 2 + n / 2] = f[i];
else
tmp[i / 2] = f[i];
}
for (int i = 0; i < n; i++)
f[i] = tmp[i];
cpd *g = f,*h = f + n / 2; //节约空间
FFT(g, n / 2, on), FFT(h, n / 2, on);
cpd wn = {cos(2 * pi / n), sin(on * 2 * pi / n)}; //单位根
//2*pi/n为弧度制中单位根的辐角大小
cpd w = {1, 0}; //单位根的k次幂
for (int k = 0; k < n / 2; k++) {
cpd u = g[k], v = w * h[k];
f[k] = u + v;
f[k + n / 2] = u - v;
w *= wn;
}
}
inline void fix(cpd *f) {
for (int i = 0; i < len; i++)f[i] /= len;
}
/*————————————华丽的分割线————————————*/
void solve(){
for(int i=0;i<len;i++){
f[i]*=g[i];
}
}
void finish(){
for(int i=0;i<=n+m;i++){
write(f[i].real()+0.5,0);
}
}
int main(){
init();
// fft(f,1),fft(g,1);
FFT(f,len,1),FFT(g,len,1);
solve();
// fft(f,-1);
FFT(f,len,-1);
fix(f);
finish();
return 0;
}
参考资料
oi-wiki。