高精度乘法从入门到入土
zhangbo1000 · · 算法·理论
题解看不到的神奇优化,好写好记(确信)的代码和公式,尽在此篇!
update 2024.3.22 添加了几种之前没有提到的 FFT 优化方式。
update 2024.3.25 添加了关于 DIF-FFT 的说明和使用。
update 2024.3.28 去除了错误的描述,添加了分裂基和效率对比。
update 2024.11.6 更正了 FFT 的英文名。
Karatsuba 分治乘法
(注:这一部分并不实用,只是前置知识少,门槛低。)
设
设
。所以我们只需要处理
由主定理分析可知复杂度为
参考代码(丑陋):
using ll=long long;
#define base 10
//使用 vector 存储数字,并写了几个相关函数,详见完整代码。
ulllint operator*(const ulllint& x,const ulllint& y){
if(x.size()==1||y.size()==1){
ulllint c(x.size()+y.size()+1,0);
for(size_t i=0;i<x.size();i++){
for(size_t j=0;j<y.size();j++){
c[i+j]+=x[i]*y[j];
}
}
for(size_t i=0;i<c.size()-1;i++){
c[i+1]+=c[i]/ulllint::base;
c[i]%=ulllint::base;
}
while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
return c;
}
size_t len=std::max(x.size(),y.size())>>1;
ulllint a(x.begin()+len,x.end()),b(x.begin(),x.begin()+len),c(y.begin()+len,y.end()),d(y.begin(),y.begin()+len);
ulllint ac(a*c),bd(b*d);
ulllint ad_bc((a+b)*(c+d)-ac-bd);
std::vector<ulllint::ll>_0(len,0);
ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
ad_bc.nums.insert(ad_bc.begin(),_0.begin(),_0.end());
return ac+ad_bc+bd;
}
优化
如果你试过运行上述代码,就会发现它的效率实际很低,这是因为在实际实现中,大量的拆分使这份代码带上了巨大的常数,一个简单有效的优化就是当
if(x.size()==1&&x.size()==1)
改为
//k 为你设定的某常数,本题我取了 65。
if(x.size()<k&&x.size()<k)
即可。
还没结束!我们还可以采取一个优化:long long 型的储存范围为
至此,我们的代码已经可以轻松通过 P1919:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<vector>
#include<iostream>
#include<iomanip>
namespace Bigint{
const size_t MAX_SIZE=1000000;
char temp[MAX_SIZE+1];
class ulllint{
friend std::istream;
friend std::ostream;
using ll=long long;
using size_t=unsigned int;
static constexpr size_t len=8;
static constexpr size_t base=100000000;
std::vector<ll>nums;
inline std::vector<ll>::const_iterator begin()const{return nums.begin();}
inline std::vector<ll>::const_iterator end()const{return nums.end();}
inline size_t size()const{return nums.size();}
inline ll& operator[](const size_t& x){return nums[x];}
inline const ll& operator[](const size_t& x)const{return nums[x];}
ulllint(const std::vector<ll>::const_iterator& first,const std::vector<ll>::const_iterator& last){
nums=std::vector<ll>(first,last);
}
ulllint(const size_t& x,const ll& y){
nums=std::vector<ll>(x,y);
}
ulllint(const std::vector<ll>&x){nums=x;}
public:
ulllint(const size_t& x=0){
nums=std::vector<ll>(1,x);
}
ulllint(const ulllint& x){
nums=x.nums;
}
ulllint& operator=(const ulllint& x){
nums=x.nums;
return *this;
}
friend std::istream& operator>>(std::istream& cin,ulllint& x);
friend std::ostream& operator<<(std::ostream& cout,const ulllint& x);
friend ulllint operator+(const ulllint& x,const ulllint& y);
friend ulllint operator-(const ulllint& x,const ulllint& y);
friend ulllint operator*(const ulllint& x,const ulllint& y);
};
std::istream& operator>>(std::istream& cin,ulllint& x){
std::cin>>temp;
const size_t l=strlen(temp);
size_t j=-1,w=1;
x.nums=std::vector<long long>(l/ulllint::len+2,0);
for(size_t i=0;i<l;i++){
if(!(i%ulllint::len))j++,w=1;
x[j]+=(temp[l-i-1]^'0')*w;
w*=10;
}
while(x.size()&&x[x.size()-1]==0)x.nums.pop_back();
return cin;
}
std::ostream& operator<<(std::ostream& cout,const ulllint& x){
if(x.size()==0)return cout;
cout<<x[x.size()-1];
for(signed int i=((int)x.size())-2;i>=0;i--){
cout<<std::setfill('0')<<std::setw(ulllint::len)<<x[i];
}
return cout;
}
ulllint operator+(const ulllint& x,const ulllint& y){
if(x.size()==0)return y;
if(y.size()==0)return x;
ulllint c(std::max(x.size(),y.size())+1,0);
for(size_t i=0;i<x.size();i++)c[i]+=x[i];
for(size_t i=0;i<y.size();i++)c[i]+=y[i];
for(size_t i=0;i<c.size()-1;i++)c[i+1]+=c[i]/ulllint::base,c[i]%=ulllint::base;
while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
return c;
}
ulllint operator-(const ulllint& x,const ulllint& y){
if(y.size()==0)return x;
ulllint z(x);
for(size_t i=0;i<y.size();i++){
if(z[i]<y[i]){
z[i]+=ulllint::base;
z[i+1]--;
}
z[i]-=y[i];
}
while(z.size()>1&&z[z.size()-1]==0)z.nums.pop_back();
return z;
}
ulllint operator*(const ulllint& x,const ulllint& y){
if(x.size()<65||y.size()<65){
ulllint c(x.size()+y.size()+1,0);
for(size_t i=0;i<x.size();i++){
for(size_t j=0;j<y.size();j++){
c[i+j]+=x[i]*y[j];
}
}
for(size_t i=0;i<c.size()-1;i++){
c[i+1]+=c[i]/ulllint::base;
c[i]%=ulllint::base;
}
while(c.size()>1&&c[c.size()-1]==0)c.nums.pop_back();
return c;
}
size_t len=std::max(x.size(),y.size())>>1;
ulllint a(x.begin()+len,x.end()),b(x.begin(),x.begin()+len),c(y.begin()+len,y.end()),d(y.begin(),y.begin()+len);
ulllint ac(a*c),bd(b*d);
ulllint ad_bc((a+b)*(c+d)-ac-bd);
std::vector<ulllint::ll>_0(len,0);
ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
ac.nums.insert(ac.nums.begin(),_0.begin(),_0.end());
ad_bc.nums.insert(ad_bc.begin(),_0.begin(),_0.end());
return ac+ad_bc+bd;
}
}
using namespace Bigint;
ulllint a,b;
int main(){
std::iostream::sync_with_stdio(0);
std::cin.tie(0);
std::cin>>a>>b;
std::cout<<a*b;
return 0;
}
快速傅里叶变换(Fast Fouier Transform,FFT)
前置知识
-
-
复数:指形如
a+bi 的数,其中i=\sqrt{-1} ,a,b 为实数。-
单位根:如果有一个数
\omega_n ,满足\omega_n^n=1 ,称这个数为n 次单位根。 -
-
主
n 次单位根的性质:-
\omega_n^0=\omega_n^n=1 -
\omega_n^k=-\omega_n^{k+\frac{n}{2}} -
\omega_n^k=\omega_{2n}^{2k} -
(上两条的推论)
\omega_{2n}^{k}=-\omega_{2n}^{k+n}
-
-
-
多项式的点值表示:对于一个一元
n 次多项式f(x)=\sum_{i=0}^{n}a_ix^i ,如果能知道它在x 为n+1 个不同的值时的取值,就可以确定唯一一个多项式。- 显然,在点值表示法下,求两个多项式乘积只需要把相同
n 个x 值时多项式的值分别相乘即可。
- 显然,在点值表示法下,求两个多项式乘积只需要把相同
正文
(为了可读性,本篇大部分代码均按照公式书写,而不进行过度卡常或依据编程语言的规则进行优化,a=a+b 等类似写法也会频繁出现。)
快速傅里叶变换可以快速(
令
令:
则:
将
这样我们只需要代入一半的单位根的幂就可以顺带求出另一半的点值啦!对于
蝶形变换
实现中,反复地分多项式是很费时间的,我们可以在算法开始前就把每个系数放到最后到达的位置,再递归计算,这样就避免了反复复制。
下面这段解释取自 OI wiki。
规律:其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标,我们称这个变换为位逆序置换(bit-reversal permutation)。
很多题解在说完这一点后就直接得出了迭代 FFT,但实际上用这种方法优化完后递归版一样跑的很快。
代码
using cp=complex<double>;
const double pi=acos(-1);
//complex 是 C++ 自带的复数类
//位逆序置换
void rev(const int& n,cp a[]){
for(int i=0,j=0;i<n;i++){
if(i>j)swap(a[i],a[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
}
//FFT,用 a[i] 储存代入 w_n^i 时的点值
void fft(cp a[],const int& n){
if(n<=1)return;
const int half=n>>1;
fft(a,half);
fft(a+half,half);
cp w(1,0),wn(cos(2*pi/n),sin(2*pi/n));
for(int i=0;i<half;i++){
cp x=a[i],y=w*a[i+half];
a[i]=x+y;
a[i+half]=x-y;
w*=wn;
}
}
优化
- 众所周知:
故我们可以将两个多项式分别放在同一多项式的实部和虚部,FFT 后再计算结果的平方,虚部的一半就是答案。
-
STL 的
complex在不开 O2 的情况下常数略大于手写。 -
模板递归!将多项式长度作为模板参数传入,使编译器可以在编译阶段确定其值并内联展开。
参考代码:
template<const int n>
void fft(cp a[]){
if(n<=1)return;
const int half=n>>1;
fft<half>(a);
fft<half>(a+half);
cp w(1,0),wn(cos(2*pi/n),sin(2*pi/n));
for(int i=0;i<half;i++){
cp x=a[i],y=w*a[i+half];
a[i]=x+y;
a[i+half]=x-y;
w*=wn;
}
}
//特化递归尽头
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
//模板参数只能是常数,需要一一特判,这也是为什么其它递归算法不这样做的原因
void runfft(cp a[],const int& n){
rev(n,a);
switch(n){
case 1<<1:fft<1<<1>(a);break;
case 1<<2:fft<1<<2>(a);break;
case 1<<3:fft<1<<3>(a);break;
case 1<<4:fft<1<<4>(a);break;
case 1<<5:fft<1<<5>(a);break;
case 1<<6:fft<1<<6>(a);break;
case 1<<7:fft<1<<7>(a);break;
case 1<<8:fft<1<<8>(a);break;
case 1<<9:fft<1<<9>(a);break;
case 1<<10:fft<1<<10>(a);break;
case 1<<11:fft<1<<11>(a);break;
case 1<<12:fft<1<<12>(a);break;
case 1<<13:fft<1<<13>(a);break;
case 1<<14:fft<1<<14>(a);break;
case 1<<15:fft<1<<15>(a);break;
case 1<<16:fft<1<<16>(a);break;
case 1<<17:fft<1<<17>(a);break;
case 1<<18:fft<1<<18>(a);break;
case 1<<19:fft<1<<19>(a);break;
case 1<<20:fft<1<<20>(a);break;
case 1<<21:fft<1<<21>(a);break;
}
}
频率抽取(DIF)
刚才的 FFT 是按照下标奇偶性分的,其实也可以按前后分,令
刚才的 FFT(时间抽取,DIT)是先递归计算,再进行合并,而现在,我们可以先直接计算每组
//DIF-FFT
template<const int n>
void fft(cp a[]){
if(n<=1)return;
const int half=n>>1;
cp w(1,0),wn(cos(pi2/n),sin(pi2/n));
for(int i=0;i<half;i++){
cp x=a[i],y=a[i+half];
a[i]=x+y;
a[i+half]=(x-y)*w;
w*=wn;
}
fft<half>(a);
fft<half>(a+half);
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
void runfft(cp a[],const int& n){
//直接开始
switch(n){
case 1<<1:fft<1<<1>(a);break;
case 1<<2:fft<1<<2>(a);break;
case 1<<3:fft<1<<3>(a);break;
case 1<<4:fft<1<<4>(a);break;
case 1<<5:fft<1<<5>(a);break;
case 1<<6:fft<1<<6>(a);break;
case 1<<7:fft<1<<7>(a);break;
case 1<<8:fft<1<<8>(a);break;
case 1<<9:fft<1<<9>(a);break;
case 1<<10:fft<1<<10>(a);break;
case 1<<11:fft<1<<11>(a);break;
case 1<<12:fft<1<<12>(a);break;
case 1<<13:fft<1<<13>(a);break;
case 1<<14:fft<1<<14>(a);break;
case 1<<15:fft<1<<15>(a);break;
case 1<<16:fft<1<<16>(a);break;
case 1<<17:fft<1<<17>(a);break;
case 1<<18:fft<1<<18>(a);break;
case 1<<19:fft<1<<19>(a);break;
case 1<<20:fft<1<<20>(a);break;
case 1<<21:fft<1<<21>(a);break;
}
//收尾再二进制逆序置换
rev(n,a);
}
这一方法的使用将在快速傅里叶逆变换中提到。
分裂基-DIT
以上的 FFT 是单次将多项式分成
而分裂基 FFT 就是其中一种,它是将基
的基础上,再用同种方法使:
代入原式,得到:
这样分,同时兼具基
参考代码:
//DIT-FFT,分裂基
template<const int n>
void fft(cp a[]){
constexpr int half=n>>1,quarter=n>>2;
ifft<half>(a);
ifft<quarter>(a+half);
ifft<quarter>(a+half+quarter);
cp w(1,0),wn(cos(pi2/n),-sin(pi2/n));
for(int i=0;i<quarter;i++){
cp w3=w*w*w;
cp tmp1=w*a[i+half],tmp2=w3*a[i+half+quarter];
cp x=a[i],y=tmp1+tmp2;
cp x1=a[i+quarter],y1=tmp1-tmp2;
y1=cp(y1.imag(),-y1.real());
a[i]=x+y;
a[i+quarter]=x1+y1;
a[i+half]=x-y;
a[i+half+quarter]=x1-y1;
w*=wn;
}
}
template<>
void fft<2>(cp a[]){
cp x=a[0],y=a[1];
a[0]=x+y;
a[1]=x-y;
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
分裂基-DIF
(为了公式,代码与介绍一致且便于叙述,下面称原下标偶数部分为左半,奇数部分为右半。)
思想同上,还是把
这里需要注意:分裂基 DIT 在递归完后,因为左右都处理完了,而右半采取的是基
参考代码:
template<const int n>
void fft(cp a[]){
const int half=n>>1,quarter=n>>2;
cp w(1,0),wn(cos(pi2/n),sin(pi2/n));
for(int i=0;i<quarter;i++){
cp w3=w*w*w;
cp x=a[i]-a[i+half],y=a[i+quarter]-a[i+half+quarter];
y=cp(y.imag(),-y.real());
a[i]=a[i]+a[i+half];
a[i+quarter]=a[i+quarter]+a[i+half+quarter];
a[i+half]=(x-y)*w;
a[i+half+quarter]=(x+y)*w3;
w*=wn;
}
fft<half>(a);
fft<quarter>(a+half);
fft<quarter>(a+half+quarter);
}
template<>
void fft<2>(cp a[]){
cp x=a[0],y=a[1];
a[0]=x+y;
a[1]=x-y;
}
template<>
void fft<1>(cp a[]){}
template<>
void fft<0>(cp a[]){}
快速数论变换(NTT)
前置知识
-
FFT 的思想。
-
OI wiki-数论基础中关于同余的部分。
-
原根的定义。
-
对于质数
p 的原根g ,g^{\frac{p-1}{n}} 在模意义下与n 次单位根有相同的性质。
正文
对于质数
为了
如果选择
如果原多项式的各项的系数不超过 long long 和快速乘,同时选择合适的模数。
优化
-
没有了浮点数误差,我们可以选一个大模数然后压位,实测选择
29\times2^{57}+1 时可以压7 位。 -
两个不超过
p 的数进行加减,结果在区间(-2p,2p) 中,我们可以用判断和加减代替部分模运算。 -
预处理原根的幂并用快速幂优化。
-
同时对两个多项式 NTT/FFT。
代码
using ll=long long;
const int N=(1<<21)+1;
const ll p=119*(1LL<<23)+1,g=3;
ll powg[N];
#define qmul(x,y) (x*y%p)
//历史遗留问题,无意义甚至更差
//快速幂
ll qpow(ll a,ll b){
ll ans=1;
while(b){
if(b&1)ans=qmul(ans,a);
a=qmul(a,a);
b>>=1;
}
return ans;
}
//预处理
void init(const int& n){
ll ig=qpow(g,p-2);
for(int i=1;i<=n;i<<=1)powg[i]=qpow(g,(p-1)/i);
for(int i=1;i<=n;i<<=1)ipowg[i]=qpow(ig,(p-1)/i);
}
void ntt2(ll a[],ll b[],const int& n){
if(n<=1)return;
const int half=n/2;
ntt2(a,b,half);
ntt2(a+half,b+half,half);
ll w=1,wn=powg[n];
for(int i=0;i<half;i++){
ll x,y;
x=a[i];
y=qmul(a[i+half],w);
a[i]=x+y;
if(a[i]>=p)a[i]-=p;
a[i+half]=(x-y+p)%p;
x=b[i];
y=qmul(b[i+half],w);
b[i]=x+y;
if(b[i]>=p)b[i]-=p;
b[i+half]=(x-y+p)%p;
w=qmul(w,wn);
}
}
快速傅里叶逆变换/快速数论逆变换(IFFT/INTT)
笔者才疏学浅,这里直接给出结论:用相乘后的各个点值作系数构成新多项式,用单位根的倒数(或原根的逆元)代替单位根(或原根)进行 FFT/NTT,得到的各项点值再除以多项式的长度(或乘以长度的逆元)就是各项系数。
DIF 和 DIT 混合使用加速
DIF 需要最后进行二进制逆序置换,DIT 需要开始时二进制逆序置换,我们用 DIF 的 FFT,DIT 的 IFFT 就相当于两个都不需要二进制逆序置换了代码又少写一点。
极端优化
前排提醒,这一部分的内容实用性远低于前面各部分,且不给出任何形式的实现,仅供大佬们参考。
-
(OI 禁止使用)使用 AVX2,SIMD 等指令集。
-
在多项式长度较短时改用迭代 FFT。
-
人工内联展开多项式较短的情况(多写几百行且效率提升低。)。
-
预处理甚至打表单位根和其幂。
效率对比
(以洛谷 P3803 为测试题目,均采用快读和三次变两次优化。)
省流:纯递归最慢,迭代次之,模板递归最快,DIF 和 DIT 混用优化和分裂基都很有较大提升,总计时间较最慢的写法(DIT,递归,基
-
DIF,DIT 混用,模板递归,分裂基。
-
DIF,DIT 混用,模板递归,基
2 。 -
DIF,DIT 混用,递归,基
2 。 -
DIF,DIT 混用,迭代,基
2 。 -
DIT,模板递归,基
2 。 -
DIT,递归,基
2 。
参考文献
-
【C++】【FFT】干货分享!超快的开源FFT高精度、大数乘法
-
再探 FFT – DIT 与 DIF,另种推导和优化
-
《算法导论》
-
OI-Wiki
特别鸣谢
-
所有参考文献作者及网站维护人员。
-
我的同学,他提供给了我部分数学知识。