FFT 简记

· · 算法·理论

P3803 【模板】多项式乘法(FFT)

Part 1

多项式有两种表示。

系数表示法:f(x)=a_0x^0+a_1x^1+...+a_{n-1}x^{n-1}

点值表示法:(x_0,f(x_0)),(x_1,f(x_1)),...,(x_{n-1},f(n-1))

关注到求点值表示法的乘积是容易的,故 FFT 的目标是,把输入的系数表示法转为点值表示法,求乘积,再将点值表示法转为系数表示法输出。

FFT 中,取 x_i=\omega_n^i。其中 \omega_n 为单位圆的第一个 n 分点(注意!此处及他处的 n 均为 2 的正整数次幂!)。

对于 \omega,我们有如下式子。

一、\omega_n=\cos\frac{2\pi}{n}+i\sin\frac{2\pi}{n}

二、\omega_n^n=\omega_n^0=1

三、\omega_n^k=\omega_{2n}^{2k}

四、\omega_n^{\frac{n}{2}}=-1

Part 2

f(x)=a_0x^0+a_1x^1+...+a_{n-1}x^{n-1}

欲求 f(x) 的点值表示法。

又令 f_0(x)=a_0x^0+a_2x^1+...+a_{n-2}x^{\frac{n}{2}-1}

再令 f_1(x)=a_1x^0+a_3x^1+...+a_{n-1}x^{\frac{n}{2}-1}

我们有 f(x)=f_0(x^2)+f_1(x^2)x

从而 f(\omega_n^i)=f_0(\omega_{\frac{n}{2}}^{i})+f_1(\omega_{\frac{n}{2}}^{i})\omega_n^i

进而只需求出 $f_0(x)$ 和 $f_1(x)$ 的点值表示法。 有 $T(n)=n+2T(\frac{n}{2})$,得 $T(n)=n\log_2n$。 于是可在 $O(n\log_2n)$ 的时间复杂度内求出 $f(x)$ 的点值表示法。 ## Part 3 不妨令 $f(x)$ 的点值表示法为 $(\omega_n^0,y_0),(\omega_n^1,y_1),...,(\omega_n^{n-1},y_{n-1})$。 又令 $c_0,c_1,...,c_n$ 使 $c_k=\sum\limits_{i=0}^{n-1}(\omega_n^{-k})^{i}y_i$。 则 $c_k=\sum\limits_{i=0}^{n-1}(\omega_n^{-k})^{i}y_i \;\;\;\;\;\;\;\;=\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{n-1}(\omega_n^{-k})^{i}a_{j}(\omega_n^i)^j \;\;\;\;\;\;\;\;=\sum\limits_{j=0}^{n-1}a_j\sum\limits_{i=0}^{n-1}(\omega^{j-k})^i \;\;\;\;\;\;\;\;=a_kn

从而对 y_0,y_1,...,y_{n-1} 再作一次 FFT 即可得出系数表示法。

Part 4

#include<bits/stdc++.h>
using namespace std;
#define dIO_USE_BUFFER
struct IO{
#ifdef dIO_USE_BUFFER
const static int BUFSIZE=1<<20;char ibuf[BUFSIZE],obuf[BUFSIZE],*p1,*p2,*pp;inline int getchar(){return(p1==p2&&(p2=(p1=ibuf)+fread(ibuf,1,BUFSIZE,stdin),p1==p2)?EOF:*p1++);}inline int putchar(char x){return((pp-obuf==BUFSIZE&&(fwrite(obuf,1,BUFSIZE,stdout),pp=obuf)),*pp=x,pp++),x;}inline IO&flush(){return fwrite(obuf,1,pp-obuf,stdout),pp=obuf,fflush(stdout),*this;}IO(){p1=p2=ibuf,pp=obuf;}~IO(){flush();}
#else
int(*getchar)()=&::getchar;int(*putchar)(int)=&::putchar;inline IO&flush(){return fflush(stdout),*this;}
#endif
string _sep=" ";int k=2;template<typename Tp,typename enable_if<is_integral<Tp>::value||is_same<Tp,__int128_t>::value>::type* =nullptr>inline int read(Tp&s){int f=1,ch=getchar();s=0;while(!isdigit(ch)&&ch!=EOF)f=(ch=='-'?-1:1),ch=getchar();if(ch==EOF)return false;while(ch=='0')ch=getchar();while(isdigit(ch))s=s*10+(ch^48),ch=getchar();s*=f;return true;}template<typename Tp,typename enable_if<is_floating_point<Tp>::value>::type* =nullptr>inline int read(Tp&s){int f=1,ch=getchar();s=0;while(!isdigit(ch)&&ch!='.'&&ch!=EOF)f=(ch=='-'?-1:1),ch=getchar();if(ch==EOF)return false;while(isdigit(ch))s=s*10+(ch^48),ch=getchar();if(ch=='.'){Tp eps=0.1;ch=getchar();while(isdigit(ch))s=s+(ch^48)*eps,ch=getchar(),eps/=10;}s*=f;return true;}inline int read(char&ch){ch=getchar();while(isspace(ch)&&ch!=EOF)ch=getchar();return ch!=EOF;}inline int read(char*c){char ch=getchar(),*s=c;while(isspace(ch)&&ch!=EOF)ch=getchar();while(!isspace(ch)&&ch!=EOF)*(c++)=ch,ch=getchar();*c='\0';return s!=c;}inline int read(string&s){s.clear();char ch=getchar();while(isspace(ch)&&ch!=EOF)ch=getchar();while(!isspace(ch)&&ch!=EOF)s+=ch,ch=getchar();return s.size()>0;}template<typename Tp=int>inline Tp read(){Tp x;read(x);return x;}template<typename Tp,typename...Ts>inline int read(Tp&x,Ts&...val){return read(x)&&read(val...);}inline int getline(char*c,const char&ed='\n'){char ch=getchar(),*s=c;while(ch!=ed&&ch!=EOF)*(c++)=ch,ch=getchar();*c='\0';return s!=c;}inline int getline(string&s,const char&ed='\n'){s.clear();char ch=getchar();while(ch!=ed&&ch!=EOF)s+=ch,ch=getchar();return s.size()>0;}template<typename Tp,typename enable_if<is_integral<Tp>::value||is_same<Tp,__int128_t>::value>::type* =nullptr>inline IO&write(Tp x){if(x<0)putchar('-'),x=-x;static char sta[41];int top=0;do sta[top++]=x%10^48,x/=10;while(x);while(top)putchar(sta[--top]);return*this;}inline IO&write(const string&str){for(char ch:str)putchar(ch);return*this;}inline IO&write(const char*str){while(*str!='\0')putchar(*(str++));return*this;}inline IO&write(char*str){return write((const char*)str);}inline IO&write(const char&ch){return putchar(ch),*this;}template<typename Tp,typename enable_if<is_floating_point<Tp>::value>::type* =nullptr>inline IO&write(Tp x){if(x>1e18||x<-1e18){write("[Floating point overflow]");throw;}if(x<0)putchar('-'),x=-x;const static long long pow10[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,1000000000,10000000000,100000000000,1000000000000,10000000000000,100000000000000,1000000000000000,10000000000000000,100000000000000000,100000000000000000,100000000000000000};const auto&n=pow10[k];long long whole=x;double tmp=(x-whole)*n;long long frac=tmp;double diff=tmp-frac;if(diff>0.5){++frac;if(frac>=n)frac=0,++whole;}else if(diff==0.5&&((frac==0U)||(frac&1U)))++frac;write(whole);if(k==0U){diff=x-whole;if((!(diff<0.5)||(diff>0.5))&&(whole&1))++whole;}else{putchar('.');static char sta[21];int count=k,top=0;while(frac){sta[top++]=frac%10^48;frac/=10,count--;}while(count--)putchar('0');while(top)putchar(sta[--top]);}return*this;}template<typename Tp,typename...Ts>inline IO&write(Tp x,Ts...val){return write(x),write(_sep),write(val...),*this;}template<typename...Ts>inline IO&writeln(Ts...val){return write(val...),putchar('\n'),*this;}template<typename...Ts>inline IO&writesp(Ts...val){return write(val...),putchar(' '),*this;}inline IO&writeln(void){return putchar('\n'),*this;}inline IO&sep(const string&s=" "){return _sep=s,*this;}inline IO&prec(const int&K=2){return k=K,*this;}}io;
const int N=4e6+5;
const double PI=acos(-1);
struct Complex{
    double x,y;
    Complex(){}
    Complex(double x,double y):x(x),y(y){}
    Complex operator+(Complex A){
        return Complex(x+A.x,y+A.y);
    }
    Complex operator-(Complex A){
        return Complex(x-A.x,y-A.y);
    }
    Complex operator*(Complex A){
        return Complex(x*A.x-y*A.y,x*A.y+y*A.x);
    }
} f[N],g[N];
int n,m;
void FFT(Complex *a,int limit,int type){
    if(limit==1) return;
    Complex a1[limit>>1],a2[limit>>1];
    for(int i=0;i<limit-1;i+=2) a1[i>>1]=a[i],a2[i>>1]=a[i+1];
    FFT(a1,limit>>1,type);
    FFT(a2,limit>>1,type);
    Complex w=Complex(cos(PI*2/limit),type*sin(PI*2/limit));
    Complex rw=Complex(1,0);
    for(int i=0;i<(limit>>1);i++){
        a[i]=a1[i]+a2[i]*rw;
        a[i+(limit>>1)]=a1[i]-a2[i]*rw;
        rw=rw*w;
    } 
}
int main(){
    io.read(n,m),++n,++m;
    for(int i=0;i<n;i++) io.read(f[i].x);
    for(int i=0;i<m;i++) io.read(g[i].x); 
    int lim=1;
    while(lim<n+m) lim<<=1;
    FFT(f,lim,1);
    FFT(g,lim,1);
    for(int i=0;i<lim;i++) f[i]=f[i]*g[i];
    FFT(f,lim,-1);
    for(int i=0;i<n+m-1;i++) io.write((int)(f[i].x/lim+0.5)),io.putchar(32);
    return 0;
}