高精度乘法从入门到入土

· · 算法·理论

题解看不到的神奇优化,好写好记(确信)的代码和公式,尽在此篇!

update 2024.3.22 添加了几种之前没有提到的 FFT 优化方式。

update 2024.3.25 添加了关于 DIF-FFT 的说明和使用。

update 2024.3.28 去除了错误的描述,添加了分裂基和效率对比。

update 2024.11.6 更正了 FFT 的英文名。

Karatsuba 分治乘法

(注:这一部分并不实用,只是前置知识少,门槛低。)

x,y 两数的位数为 n,则正常情况下进行乘法需要把 x 的每一位分别乘以 y 的每一位,故时间复杂度为 O(n^2),运行慢,我们可以利用分治思想优化(下面的所有形如 \dfrac{n}{k} 的数均向下取整):

x=a\times10^{\frac{n}{2}}+b,y=c\times10^{\frac{n}{2}}+d,则:

xy&=ac\times10^{\frac{n}{2}\times2}+(ad+bc)\times10^{\frac{n}{2}}+bd \\&=ac\times10^{\frac{n}{2}\times2}+[(a+b)\times(c+d)-ac-bd]\times10^{\frac{n}{2}}+bd \end{aligned}

。所以我们只需要处理 (a+b)\times(c+d),ac,bd 三组长度为 \dfrac{n}{2} 的数相乘,总计算量为 \dfrac{3}{4}n^2+O(n),继续用同样的方法计算 (a+b)\times(c+d),ac,bd,就能优化更多,设数据规模为 n 时运算量为 T(n),则

T(n)=\begin{cases}O(1)&(n\le1)\\3T(\frac{n}{2})+O(n)&(n>1)\end{cases}

由主定理分析可知复杂度为 O(n^{\log_23})

参考代码(丑陋):

using ll=long long;
#define base 10
//使用 vector 存储数字,并写了几个相关函数,详见完整代码。
    ulllint operator*(const ulllint& x,const ulllint& y){
        if(x.size()==1||y.size()==1){
            ulllint c(x.size()+y.size()+1,0);
            for(size_t i=0;i<x.size();i++){
                for(size_t j=0;j<y.size();j++){
                    c[i+j]+=x[i]*y[j];
                }
            }
            for(size_t i=0;i<c.size()-1;i++){
                c[i+1]+=c[i]/ulllint::base;
                c[i]%=ulllint::base;
            }
            while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
            return c;
        }
        size_t len=std::max(x.size(),y.size())>>1;
        ulllint a(x.begin()+len,x.end()),b(x.begin(),x.begin()+len),c(y.begin()+len,y.end()),d(y.begin(),y.begin()+len);
        ulllint ac(a*c),bd(b*d);
        ulllint ad_bc((a+b)*(c+d)-ac-bd);
        std::vector<ulllint::ll>_0(len,0);
        ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
        ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
        ad_bc.nums.insert(ad_bc.begin(),_0.begin(),_0.end());
        return ac+ad_bc+bd;
    }

优化

如果你试过运行上述代码,就会发现它的效率实际很低,这是因为在实际实现中,大量的拆分使这份代码带上了巨大的常数,一个简单有效的优化就是当 x,y 的位数小于一定值时就直接改用暴力算法,而不是直接拆到不能再拆。只需将

if(x.size()==1&&x.size()==1)

改为

//k 为你设定的某常数,本题我取了 65。
if(x.size()<k&&x.size()<k)

即可。

还没结束!我们还可以采取一个优化:long long 型的储存范围为 2^{63}-1>9\times10^{18},只存 1 位太可惜了,我们可以用一个元素存 8 位!只需将上述代码中 base 的值改为 10^8 即可。

至此,我们的代码已经可以轻松通过 P1919:

#include<cstdio>
#include<cstring>
#include<cmath>
#include<vector>
#include<iostream>
#include<iomanip>
namespace Bigint{
    const size_t MAX_SIZE=1000000;
    char temp[MAX_SIZE+1];
    class ulllint{
        friend std::istream;
        friend std::ostream;
        using ll=long long;
        using size_t=unsigned int;
        static constexpr size_t len=8;
        static constexpr size_t base=100000000;
        std::vector<ll>nums;
        inline std::vector<ll>::const_iterator begin()const{return nums.begin();}
        inline std::vector<ll>::const_iterator end()const{return nums.end();}
        inline size_t size()const{return nums.size();}
        inline ll& operator[](const size_t& x){return nums[x];}
        inline const ll& operator[](const size_t& x)const{return nums[x];}
        ulllint(const std::vector<ll>::const_iterator& first,const std::vector<ll>::const_iterator& last){
            nums=std::vector<ll>(first,last);
        }
        ulllint(const size_t& x,const ll& y){
            nums=std::vector<ll>(x,y);
        }
        ulllint(const std::vector<ll>&x){nums=x;}
    public:
        ulllint(const size_t& x=0){
            nums=std::vector<ll>(1,x);
        }
        ulllint(const ulllint& x){
            nums=x.nums;
        }
        ulllint& operator=(const ulllint& x){
            nums=x.nums;
            return *this;
        }
        friend std::istream& operator>>(std::istream& cin,ulllint& x);
        friend std::ostream& operator<<(std::ostream& cout,const ulllint& x);
        friend ulllint operator+(const ulllint& x,const ulllint& y);
        friend ulllint operator-(const ulllint& x,const ulllint& y);
        friend ulllint operator*(const ulllint& x,const ulllint& y);
    };
    std::istream& operator>>(std::istream& cin,ulllint& x){
        std::cin>>temp;
        const size_t l=strlen(temp);
        size_t j=-1,w=1;
        x.nums=std::vector<long long>(l/ulllint::len+2,0);
        for(size_t i=0;i<l;i++){
            if(!(i%ulllint::len))j++,w=1;
            x[j]+=(temp[l-i-1]^'0')*w;
            w*=10;
        }
        while(x.size()&&x[x.size()-1]==0)x.nums.pop_back();
        return cin;
    }
    std::ostream& operator<<(std::ostream& cout,const ulllint& x){
        if(x.size()==0)return cout;
        cout<<x[x.size()-1];
        for(signed int i=((int)x.size())-2;i>=0;i--){
            cout<<std::setfill('0')<<std::setw(ulllint::len)<<x[i];
        }
        return cout;
    }
    ulllint operator+(const ulllint& x,const ulllint& y){
        if(x.size()==0)return y;
        if(y.size()==0)return x;
        ulllint c(std::max(x.size(),y.size())+1,0);
        for(size_t i=0;i<x.size();i++)c[i]+=x[i];
        for(size_t i=0;i<y.size();i++)c[i]+=y[i];
        for(size_t i=0;i<c.size()-1;i++)c[i+1]+=c[i]/ulllint::base,c[i]%=ulllint::base;
        while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
        return c;
    }
    ulllint operator-(const ulllint& x,const ulllint& y){
        if(y.size()==0)return x;
        ulllint z(x);
        for(size_t i=0;i<y.size();i++){
            if(z[i]<y[i]){
                z[i]+=ulllint::base;
                z[i+1]--;
            }
            z[i]-=y[i];
        }
        while(z.size()>1&&z[z.size()-1]==0)z.nums.pop_back();
        return z;
    }
    ulllint operator*(const ulllint& x,const ulllint& y){
        if(x.size()<65||y.size()<65){
            ulllint c(x.size()+y.size()+1,0);
            for(size_t i=0;i<x.size();i++){
                for(size_t j=0;j<y.size();j++){
                    c[i+j]+=x[i]*y[j];
                }
            }
            for(size_t i=0;i<c.size()-1;i++){
                c[i+1]+=c[i]/ulllint::base;
                c[i]%=ulllint::base;
            }
            while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
            return c;
        }
        size_t len=std::max(x.size(),y.size())>>1;
        ulllint a(x.begin()+len,x.end()),b(x.begin(),x.begin()+len),c(y.begin()+len,y.end()),d(y.begin(),y.begin()+len);
        ulllint ac(a*c),bd(b*d);
        ulllint ad_bc((a+b)*(c+d)-ac-bd);
        std::vector<ulllint::ll>_0(len,0);
        ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
        ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
        ad_bc.nums.insert(ad_bc.begin(),_0.begin(),_0.end());
        return ac+ad_bc+bd;
    }
}
using namespace Bigint;
ulllint a,b;
int main(){
    std::iostream::sync_with_stdio(0);
    std::cin.tie(0);
    std::cin>>a>>b;
    std::cout<<a*b;
    return 0;
}

快速傅里叶变换(Fast Fouier Transform,FFT)

前置知识

正文

(为了可读性,本篇大部分代码均按照公式书写,而不进行过度卡常或依据编程语言的规则进行优化,a=a+b 等类似写法也会频繁出现。)

快速傅里叶变换可以快速(O(n\log n))计算多项式在 xn+1 个单位根的幂时的值,以此做到快速将多项式转化为点值表示来快速(O(n))相乘。

f(x)=\sum_{i=0}^{n-1}a_ix^i,其中 n2 的整数次幂(原因等下会讲),则:

\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i \\&=(\sum_{i=0}^{\frac{n}{2}-1}a_{2i}x^{2i})+(\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}x^{2i+1}) \\&=(\sum_{i=0}^{\frac{n}{2}-1}a_{2i}x^{2i})+x(\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}x^{2i}) \end{aligned}

令:

g(x)=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}x^{i} h(x)=\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}x^{i}

则:

f(x)=g(x^2)+x\times h(x^2)

w_n^kw_n^{k+\frac{n}{2}} 分别代入,由单位根的性质得:

f(\omega_n^k)=g(\omega_{\frac{n}{2}}^{k})+\omega_n^kh(\omega_{\frac{n}{2}}^k) f(\omega_n^{k+\frac{n}{2}})=g(\omega_{\frac{n}{2}}^{k})-\omega_n^kh(w_{\frac{n}{2}}^k)

这样我们只需要代入一半的单位根的幂就可以顺带求出另一半的点值啦!对于 h(x) g(x) 显然可以递归地继续用同样的方法,因为通过一半求出另一半需要每次都严格的将多项式分成相等长度的两部分,所以 n 必须为 2 的整次幂。

蝶形变换

实现中,反复地分多项式是很费时间的,我们可以在算法开始前就把每个系数放到最后到达的位置,再递归计算,这样就避免了反复复制。

下面这段解释取自 OI wiki。

规律:其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标,我们称这个变换为位逆序置换(bit-reversal permutation)。

很多题解在说完这一点后就直接得出了迭代 FFT,但实际上用这种方法优化完后递归版一样跑的很快。

代码

using cp=complex<double>;
const double pi=acos(-1);
  //complex 是 C++ 自带的复数类
  //位逆序置换
void rev(const int& n,cp a[]){
    for(int i=0,j=0;i<n;i++){
        if(i>j)swap(a[i],a[j]);
        for(int l=(n>>1);(j^=l)<l;l>>=1);
    }
}
  //FFT,用 a[i] 储存代入 w_n^i 时的点值
void fft(cp a[],const int& n){
    if(n<=1)return;
    const int half=n>>1;
    fft(a,half);
    fft(a+half,half);
    cp w(1,0),wn(cos(2*pi/n),sin(2*pi/n));
    for(int i=0;i<half;i++){
        cp x=a[i],y=w*a[i+half];
        a[i]=x+y;
        a[i+half]=x-y;
        w*=wn;
    }
}

优化

(a+bi)^2=a^2-b^2+2abi

故我们可以将两个多项式分别放在同一多项式的实部和虚部,FFT 后再计算结果的平方,虚部的一半就是答案。

参考代码:


template<const int n>
void fft(cp a[]){
    if(n<=1)return;
    const int half=n>>1;
    fft<half>(a);
    fft<half>(a+half);
    cp w(1,0),wn(cos(2*pi/n),sin(2*pi/n));
    for(int i=0;i<half;i++){
        cp x=a[i],y=w*a[i+half];
        a[i]=x+y;
        a[i+half]=x-y;
        w*=wn;
    }
}
//特化递归尽头
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
//模板参数只能是常数,需要一一特判,这也是为什么其它递归算法不这样做的原因
void runfft(cp a[],const int& n){
    rev(n,a);
    switch(n){
    case 1<<1:fft<1<<1>(a);break;
    case 1<<2:fft<1<<2>(a);break;
    case 1<<3:fft<1<<3>(a);break;
    case 1<<4:fft<1<<4>(a);break;
    case 1<<5:fft<1<<5>(a);break;
    case 1<<6:fft<1<<6>(a);break;
    case 1<<7:fft<1<<7>(a);break;
    case 1<<8:fft<1<<8>(a);break;
    case 1<<9:fft<1<<9>(a);break;
    case 1<<10:fft<1<<10>(a);break;
    case 1<<11:fft<1<<11>(a);break;
    case 1<<12:fft<1<<12>(a);break;
    case 1<<13:fft<1<<13>(a);break;
    case 1<<14:fft<1<<14>(a);break;
    case 1<<15:fft<1<<15>(a);break;
    case 1<<16:fft<1<<16>(a);break;
    case 1<<17:fft<1<<17>(a);break;
    case 1<<18:fft<1<<18>(a);break;
    case 1<<19:fft<1<<19>(a);break;
    case 1<<20:fft<1<<20>(a);break;
    case 1<<21:fft<1<<21>(a);break;
    }
}

频率抽取(DIF)

刚才的 FFT 是按照下标奇偶性分的,其实也可以按前后分,令 k 为偶数:

\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i\\ &=(\sum_{i=0}^{\frac{n}{2}-1}a_ix^i)+\sum_{i=0}^{\frac{n}{2}-1}a_{i+\frac{n}{2}}x^{i+\frac{n}{2}} \\&=(\sum_{i=0}^{\frac{n}{2}-1}a_ix^i)+x^{\frac{n}{2}}\sum_{i=0}^{\frac{n}{2}-1}a_{i+\frac{n}{2}}x^{i} \\ f(\omega_n^k)&=(\sum_{i=0}^{\frac{n}{2}-1}a_i\omega_n^{ik})+\omega_{n}^{\frac{nk}{2}}\sum_{i=0}^{\frac{n}{2}-1}a_{i+\frac{n}{2}}\omega_n^{ik} \\&=\sum_{i=0}^{\frac{n}{2}-1}\omega_n^{ik}(a_i+a_{i+\frac{n}{2}}) \\ f(\omega_n^{k+1})&=\sum_{i=0}^{\frac{n}{2}-1}\omega_n^{i(k+1)}(a_i-a_{i+\frac{n}{2}}) \\&=\sum_{i=0}^{\frac{n}{2}-1}\omega_n^{ik}(a_i-a_{i+\frac{n}{2}})\omega_n^i \end{aligned}

刚才的 FFT(时间抽取,DIT)是先递归计算,再进行合并,而现在,我们可以先直接计算每组 (a_i-a_{i+\frac{n}{2}})\omega_n^i(a_i+a_{i+\frac{n}{2}}) 的值,再递归地计算两部分,稍微画蝶形图分析一下,我们这样计算可以直接传入原数组,而得到的结果是二进制逆序的,参考代码:

//DIF-FFT
template<const int n>
void fft(cp a[]){
    if(n<=1)return;
    const int half=n>>1;
    cp w(1,0),wn(cos(pi2/n),sin(pi2/n));
    for(int i=0;i<half;i++){
        cp x=a[i],y=a[i+half];
        a[i]=x+y;
        a[i+half]=(x-y)*w;
        w*=wn;
    }
    fft<half>(a);
    fft<half>(a+half);
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
void runfft(cp a[],const int& n){
    //直接开始
    switch(n){
        case 1<<1:fft<1<<1>(a);break;
        case 1<<2:fft<1<<2>(a);break;
        case 1<<3:fft<1<<3>(a);break;
        case 1<<4:fft<1<<4>(a);break;
        case 1<<5:fft<1<<5>(a);break;
        case 1<<6:fft<1<<6>(a);break;
        case 1<<7:fft<1<<7>(a);break;
        case 1<<8:fft<1<<8>(a);break;
        case 1<<9:fft<1<<9>(a);break;
        case 1<<10:fft<1<<10>(a);break;
        case 1<<11:fft<1<<11>(a);break;
        case 1<<12:fft<1<<12>(a);break;
        case 1<<13:fft<1<<13>(a);break;
        case 1<<14:fft<1<<14>(a);break;
        case 1<<15:fft<1<<15>(a);break;
        case 1<<16:fft<1<<16>(a);break;
        case 1<<17:fft<1<<17>(a);break;
        case 1<<18:fft<1<<18>(a);break;
        case 1<<19:fft<1<<19>(a);break;
        case 1<<20:fft<1<<20>(a);break;
        case 1<<21:fft<1<<21>(a);break;
    }
    //收尾再二进制逆序置换
    rev(n,a);
}

这一方法的使用将在快速傅里叶逆变换中提到。

分裂基-DIT

以上的 FFT 是单次将多项式分成 2 份的,我们可以称之为基 2 FFT,相应的,还有基 4,基 8 甚至基 16 等多种实现方法。

而分裂基 FFT 就是其中一种,它是将基 2 FFT 中下标为奇数的部分再次拆分(似乎可以算是不太严格的基 3 FFT?),即在:

f(x)=g(x^2)+x\times h(x^2)

的基础上,再用同种方法使:

h(x)=h_1(x^2)+x\times h_2(x^2)

代入原式,得到:

f(x)=g(x^2)+x\times h_1(x^4)+x^3\times h_2(x^4)

这样分,同时兼具基 2 FFT 的灵活性和单次拆分多部分的高效率。

参考代码:

//DIT-FFT,分裂基
template<const int n>
void fft(cp a[]){
    constexpr int half=n>>1,quarter=n>>2;
    ifft<half>(a);
    ifft<quarter>(a+half);
    ifft<quarter>(a+half+quarter);
    cp w(1,0),wn(cos(pi2/n),-sin(pi2/n));
    for(int i=0;i<quarter;i++){
        cp w3=w*w*w;
        cp tmp1=w*a[i+half],tmp2=w3*a[i+half+quarter];
        cp x=a[i],y=tmp1+tmp2;
        cp x1=a[i+quarter],y1=tmp1-tmp2;
        y1=cp(y1.imag(),-y1.real());
        a[i]=x+y;
        a[i+quarter]=x1+y1;
        a[i+half]=x-y;
        a[i+half+quarter]=x1-y1;
        w*=wn;
    }
}
template<>
void fft<2>(cp a[]){
    cp x=a[0],y=a[1];
    a[0]=x+y;
    a[1]=x-y;
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}

分裂基-DIF

(为了公式,代码与介绍一致且便于叙述,下面称原下标偶数部分为左半,奇数部分为右半。)

思想同上,还是把 f(x) 写成一个基 2 FFT 加两个基 4 FFT 的形式即可。

这里需要注意:分裂基 DIT 在递归完后,因为左右都处理完了,而右半采取的是基 4 FFT 的拆分,所以用类似基 4 FFT 的方式合并,而分裂基 DIF 在拆分之前,因为左半和右半都需要继续递归,所以左半按基 2 拆,右半按基 4 拆。

参考代码:


template<const int n>
void fft(cp a[]){
    const int half=n>>1,quarter=n>>2;
    cp w(1,0),wn(cos(pi2/n),sin(pi2/n));
    for(int i=0;i<quarter;i++){
        cp w3=w*w*w;
        cp x=a[i]-a[i+half],y=a[i+quarter]-a[i+half+quarter];
        y=cp(y.imag(),-y.real());
        a[i]=a[i]+a[i+half];
        a[i+quarter]=a[i+quarter]+a[i+half+quarter];
        a[i+half]=(x-y)*w;
        a[i+half+quarter]=(x+y)*w3;
        w*=wn;
    }
    fft<half>(a);
    fft<quarter>(a+half);
    fft<quarter>(a+half+quarter);
}
template<>
void fft<2>(cp a[]){
    cp x=a[0],y=a[1];
    a[0]=x+y;
    a[1]=x-y;
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}

快速数论变换(NTT)

前置知识

正文

对于质数 p 的原根 g,g^{\frac{p-1}{n}} 在模意义下与 n 次单位根有相同的性质,所以可以用原根的幂代替复数单位根代入进行模意义下的 FFT,以此规避浮点数计算并加快速度。

为了 \dfrac{p-1}{n} 为整数,我们的 p 应取形如 q\times2^k+1 的质数,比如 4179340454199820289=29\times2^{57}+1,998244353=119\times2^{23}+1 等,这两个的原根都是 3

如果选择 q\times2^k+1 作模数,则多项式长度不能超过 2^k

如果原多项式的各项的系数不超过 x,长度为 n,则相乘后各项系数不超过 x^2n,注意开 long long 和快速乘,同时选择合适的模数。

优化

代码

using ll=long long;
const int N=(1<<21)+1;
const ll p=119*(1LL<<23)+1,g=3;
ll powg[N];
#define qmul(x,y) (x*y%p)
//历史遗留问题,无意义甚至更差

//快速幂
ll qpow(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1)ans=qmul(ans,a);
        a=qmul(a,a);
        b>>=1;
    }
    return ans;
}

//预处理
void init(const int& n){
    ll ig=qpow(g,p-2);
    for(int i=1;i<=n;i<<=1)powg[i]=qpow(g,(p-1)/i);
    for(int i=1;i<=n;i<<=1)ipowg[i]=qpow(ig,(p-1)/i);
}

void ntt2(ll a[],ll b[],const int& n){
    if(n<=1)return;
    const int half=n/2;
    ntt2(a,b,half);
    ntt2(a+half,b+half,half);
    ll w=1,wn=powg[n];
    for(int i=0;i<half;i++){
        ll x,y;
        x=a[i];
        y=qmul(a[i+half],w);
        a[i]=x+y;
        if(a[i]>=p)a[i]-=p;
        a[i+half]=(x-y+p)%p;
        x=b[i];
        y=qmul(b[i+half],w);
        b[i]=x+y;
        if(b[i]>=p)b[i]-=p;
        b[i+half]=(x-y+p)%p;
        w=qmul(w,wn);
    }
}

快速傅里叶逆变换/快速数论逆变换(IFFT/INTT)

笔者才疏学浅,这里直接给出结论:用相乘后的各个点值作系数构成新多项式,用单位根的倒数(或原根的逆元)代替单位根(或原根)进行 FFT/NTT,得到的各项点值再除以多项式的长度(或乘以长度的逆元)就是各项系数。

DIF 和 DIT 混合使用加速

DIF 需要最后进行二进制逆序置换,DIT 需要开始时二进制逆序置换,我们用 DIF 的 FFT,DIT 的 IFFT 就相当于两个都不需要二进制逆序置换了代码又少写一点。

极端优化

前排提醒,这一部分的内容实用性远低于前面各部分,且不给出任何形式的实现,仅供大佬们参考。

效率对比

(以洛谷 P3803 为测试题目,均采用快读和三次变两次优化。)

省流:纯递归最慢,迭代次之,模板递归最快,DIF 和 DIT 混用优化和分裂基都很有较大提升,总计时间较最慢的写法(DIT,递归,基 2)减少了约 40\%

参考文献

特别鸣谢