如何速通卷积?

· · 算法·理论

可能更好的阅读体验

有的题里面卷积是必要的,不会卷积就可能被暴打。

本文旨在帮助和我一样没怎么学多项式的人速通卷积。

其中可能有一些定义和结论,你不需要关心其证明也可以学会卷积,因此本文中不会证明结论。

点值表示法

通过系数表示法给出两个多项式(即给出各项系数) f(x)=a_0x^0+\dots+a_nx^n,g(x)=b_0x^0+\dots+b_mx^m,求 h(x)=f(x)g(x)=c_0x^0+\dots+c_{n+m}x^{n+m} 即其乘积的各项系数。

结论 1:根据 n 次多项式 f(x)n+1 个不同 x 处的取值 (x_1,y_1),(x_2,y_2),\dots,(x_{n+1},y_{n+1}) 可以唯一确定 f(x)

定义 1:根据结论 1 可以用 n+1 个不同 x 处的取值表示一个 n 次多项式,将这种表示方法称为点值表示法。

因此可以先求出 f(x),g(x)n+m+1 个不同 x 处的取值,然后相乘即可得到 h(x)n+m+1 个不同 x 处的取值,再根据这些值求出 h(x) 的各项系数。

于是现在问题变为了在系数表示法和点值表示法之间快速转化。

系数表示法 -> 点值表示法

直接暴力算即可做到 O((n+m)^2),但是显然不够快。

f(x)=f_0(x^2)+xf_1(x^2),即将其偶数次系数和奇数次系数分别拿出来组成新的多项式 f_0(x),f_1(x)

那么只要快速合并即可分治,为了分治可以将项数补到最小且 >n+m2 的整数次幂 2^p,但是合并好像很难。

单位根

不过注意到选的数是没有任何限制的,所以不妨找一些有特殊性质的数使其能够快速合并。

定义 2:令平面直角坐标系上的点 (x,y) 表示 x+iy,其中 i 是虚数单位满足 i^2=-1,将这个平面直角坐标系称为复平面。

复数运算:

typedef double db;
struct cpx{
    db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
    return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
    return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
    return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
    return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}

定义 3:将平面直角坐标系上以原点为圆心单位长度为半径的圆称为单位圆。

定义 4:将复平面上的单位圆平均分为 n(n\ge2) 段且 (1,0) 为其中一个分段点,将从 (1,0) 开始逆时针走到的第 2 个分段点表示的数称为 \omega_n。根据三角函数基础知识,可知 \omega_n=\cos\frac{2\pi}{2^p}+i\sin\frac{2\pi}{2^p}

结论 2:\omega_n^k 对应从 (1,0) 开始逆时针走到的第 k+1 个分段点。

结论 3:当 2\mid n 时,-\omega_n^{k+\frac{n}{2}}=\omega_n^k

快速合并

不难发现 \omega_n^k 有一些良好性质,因此考虑令 x_i=\omega_{2^p}^{i-1}

于是可以注意到当 j>2^{p-1} 时,f(x_j)=f_0(x_j^2)+x_jf_1(x_j^2)=f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)

因此只需求出 f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}}) 即可,直接分治即可,时间复杂度 O(2^pp)=O((n+m)\log(n+m))

卡常

首先要把递归写成循环形式。

考虑将往下分的过程优化。(此过程中需要将偶数次系数和奇数次系数分到两边)

定义 5:将 i 在这个过程结束后移到的位置称为 to_i

结论 4:to_i 即为 i 的二进制表示将前 p 位 reverse 得到的数。

因此有递推式:to_i=\lfloor\frac{to_{\lfloor\frac{i}{2}\rfloor}}{2}\rfloor+[2\nmid i]2^{p-1}

于是可以将该过程优化到线性。

const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
    rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
    rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
    for(int k=2;k<=len;k<<=1){
        cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
        for(int i=0;i<len;i+=k){
            cpx x={1,0};
            for(int j=0;j<(k>>1);j++){
                cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
                a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
                x=x*w;
            }
        }
    }
}

点值表示法 -> 系数表示法

直接根据上面代码倒推即可。

void ifft(int len,cpx *a){
    for(int k=len;k>=2;k>>=1){
        cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
        for(int i=0;i<len;i+=k){
            cpx x={1,0};
            for(int j=0;j<(k>>1);j++){
                cpx p=a[i+j],q=a[i+j+(k>>1)];
                a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
                x=x*w;
            }
        }
    }
    rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
    rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}

于是我们已经可以写出卷积代码了:

void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
    int len=1;
    while(len<=(n+m))len<<=1;
    rep(i,0,len-1)a[i]={(db)A[i],0.0};
    rep(i,0,len-1)b[i]={(db)B[i],0.0};
    fft(len,a);
    fft(len,b);
    rep(i,0,len-1)c[i]=a[i]*b[i];
    ifft(len,c);
    rep(i,0,len-1)C[i]=(ll)round(c[i].x);
}

三次变两次优化

原理:(a+bi)^2=(a^2-b^2)+(2ab)i

于是可以将 f(x),g(x) 的系数分别放在实部和虚部,求平方后虚部除以 2 便是 h(x)

cpx a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
    int len=1;
    while(len<=(n+m))len<<=1;
    rep(i,0,len-1)a[i]={(db)A[i],(db)B[i]};
    fft(len,a);
    rep(i,0,len-1)a[i]=a[i]*a[i];
    ifft(len,a);
    rep(i,0,len-1)C[i]=(ll)round(a[i].y/2.0);
}

考虑模意义

显然三角函数与浮点数运算会产生精度误差,同时大多数情况下都是在特定模意义下使用卷积,因此考虑使用整数代替这些浮点数运算,只需要在特定模意义中找到和单位根有类似性质的数即可。

可以将 mod 分解,使用 CRT 合并即可。

一般 p-12 的较高整数次幂因子时可以使用。

原根

定义 6:对于奇质数 p,将满足 g^1,\dots,g^{p-1} 互不相同的 g 称为其原根。

结论 5:若 n 存在原根,则其最小原根是 O(n^\frac{1}{4}) 的。

结论 6:若 x 不为原根,则 \exists y,x^{\frac{p-1}{y}}\equiv 1 \pmod p

于是可以暴力枚举找最小原根。

## 代替单位根 结论 7:$g^{\frac{p-1}{2}}\equiv p-1\pmod p$。 因此考虑令 $x_i=(g^{\frac{mod-1}{2^p}})^{i-1}$。 于是可以注意到当 $j>2^{p-1}$ 时,$f(x_j)\equiv f_0(x_j^2)+x_jf_1(x_j^2)\equiv f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)$。 因此只需求出 $f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}})$ 即可,直接分治即可,时间复杂度 $O(2^pp)=O((n+m)\log(n+m))$。 ```cpp const ll mod=998244353; const ll I2=(mod+1)/2; const ll G=3; ll ksm(ll a,ll b,ll p){ a=a%p; ll r=1; while(b){ if(b&1)r=r*a%p; a=a*a%p; b>>=1; } return r%p; } const ll IG=ksm(G,mod-2,mod); void ntt(int len,ll *a){ rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); for(int k=2;k<=len;k<<=1){ ll w=ksm(G,(mod-1)/k,mod); for(int i=0;i<len;i+=k){ ll x=1; for(int j=0;j<(k>>1);j++){ ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod; a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod; x=x*w%mod; } } } } void intt(int len,ll *a){ for(int k=len;k>=2;k>>=1){ ll w=ksm(IG,(mod-1)/k,mod); for(int i=0;i<len;i+=k){ ll x=1; for(int j=0;j<(k>>1);j++){ ll p=a[i+j],q=a[i+j+(k>>1)]; a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod; x=x*w%mod; } } } rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); } ll ntt_a[N],ntt_b[N],ntt_c[N]; void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){ int len=1; while(len<=(n+m))len<<=1; rep(i,0,len-1)ntt_a[i]=A[i]; rep(i,0,len-1)ntt_b[i]=B[i]; ntt(len,ntt_a); ntt(len,ntt_b); rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod; intt(len,ntt_c); rep(i,0,len-1)C[i]=ntt_c[i]; } ``` ## 模板题代码 [题目链接](https://www.luogu.com.cn/problem/P3803) ```cpp #include<bits/stdc++.h> #define rep(i,l,r) for(int i=(l);i<=(r);i++) #define per(i,r,l) for(int i=(r);i>=(l);i--) #define repll(i,l,r) for(ll i=(l);i<=(r);i++) #define perll(i,r,l) for(ll i=(r);i>=(l);i--) #define pb push_back #define ins insert #define clr clear using namespace std; namespace ax_by_c{ typedef long long ll; const int N=4e6+5; namespace Bpoly{ typedef double db; struct cpx{ db x,y; }; cpx operator + (const cpx &a,const cpx &b){ return {a.x+b.x,a.y+b.y}; } cpx operator - (const cpx &a,const cpx &b){ return {a.x-b.x,a.y-b.y}; } cpx operator * (const cpx &a,const cpx &b){ return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x}; } cpx operator / (const cpx &a,const cpx &b){ return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)}; } const db PI=acos(-1.0); int to[N]; void fft(int len,cpx *a){ rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); for(int k=2;k<=len;k<<=1){ cpx w={cos(PI*2.0/k),sin(PI*2.0/k)}; for(int i=0;i<len;i+=k){ cpx x={1,0}; for(int j=0;j<(k>>1);j++){ cpx p=a[i+j],q=a[i+j+(k>>1)]*x; a[i+j]=p+q,a[i+j+(k>>1)]=p-q; x=x*w; } } } } void ifft(int len,cpx *a){ for(int k=len;k>=2;k>>=1){ cpx w={cos(PI*2.0/k),sin(PI*2.0/k)}; for(int i=0;i<len;i+=k){ cpx x={1,0}; for(int j=0;j<(k>>1);j++){ cpx p=a[i+j],q=a[i+j+(k>>1)]; a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x; x=x*w; } } } rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); } cpx fft_a[N]; void convolution_fft(int n,ll *A,int m,ll *B,ll *C){ int len=1; while(len<=(n+m))len<<=1; rep(i,0,len-1)fft_a[i]={(db)A[i],(db)B[i]}; fft(len,fft_a); rep(i,0,len-1)fft_a[i]=fft_a[i]*fft_a[i]; ifft(len,fft_a); rep(i,0,len-1)C[i]=(ll)round(fft_a[i].y/2.0); } const ll mod=998244353; const ll I2=(mod+1)/2; const ll G=3; ll ksm(ll a,ll b,ll p){ a=a%p; ll r=1; while(b){ if(b&1)r=r*a%p; a=a*a%p; b>>=1; } return r%p; } const ll IG=ksm(G,mod-2,mod); void ntt(int len,ll *a){ rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); for(int k=2;k<=len;k<<=1){ ll w=ksm(G,(mod-1)/k,mod); for(int i=0;i<len;i+=k){ ll x=1; for(int j=0;j<(k>>1);j++){ ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod; a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod; x=x*w%mod; } } } } void intt(int len,ll *a){ for(int k=len;k>=2;k>>=1){ ll w=ksm(IG,(mod-1)/k,mod); for(int i=0;i<len;i+=k){ ll x=1; for(int j=0;j<(k>>1);j++){ ll p=a[i+j],q=a[i+j+(k>>1)]; a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod; x=x*w%mod; } } } rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1)); rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]); } ll ntt_a[N],ntt_b[N],ntt_c[N]; void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){ int len=1; while(len<=(n+m))len<<=1; rep(i,0,len-1)ntt_a[i]=A[i]; rep(i,0,len-1)ntt_b[i]=B[i]; ntt(len,ntt_a); ntt(len,ntt_b); rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod; intt(len,ntt_c); rep(i,0,len-1)C[i]=ntt_c[i]; } }; int n,m; ll a[N],b[N],c[N]; void slv(int _csid,int _csi){ scanf("%d %d",&n,&m); rep(i,0,n)scanf("%lld",&a[i]); rep(i,0,m)scanf("%lld",&b[i]); // Bpoly::convolution_fft(n,a,m,b,c); Bpoly::convolution_ntt(n,a,m,b,c); rep(i,0,n+m)printf("%lld ",c[i]); } void main(){ // ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); int T=1,csid=0; // scanf("%d",&csid); // scanf("%d",&T); rep(i,1,T)slv(csid,i); } } int main(){ string __name=""; if(__name!=""){ freopen((__name+".in").c_str(),"r",stdin); freopen((__name+".out").c_str(),"w",stdout); } ax_by_c::main(); return 0; } ```