题解:P1919 【模板】高精度乘法 | A*B Problem 升级版
masonxiong · · 题解
前言
本篇题解主要介绍了快速傅里叶变换以及其优化技巧。
作者目前为本题最优解第二名。
快速傅里叶变换详解
快速傅里叶变换(Fast Fourier transform,FFT)是用于计算多项式乘法的重要算法,也是其它多项式算法的根本。
下面我将尽可能详细地讲解它。
前置知识
非常基础。
- 暴力高精度乘法。
- 初中数学知识。
从整数乘法到多项式乘法
首先使用小学学习竖式计算的时候就已经会了的拆位将整数拆开,变成一个多项式的表示形式。
举个例子:
我们把拆开的式子中的
于是你完成了从整数乘法到多项式乘法的转换。非常简单,不是吗。
多项式乘法的一些记号
假设我们有多项式
- 次数:
n 。注意n 次多项式有n+1 项。 - 系数:
f_i 被称作F 的i 次项系数。
别样的多项式表示形式
初中学平面直角坐标系时,我们已经充分掌握了用两个点确定一条直线解析式的能力,使用待定系数法列出方程即可。
我们又知道:
例如直线就是一个一次多项式,它就可以用两个点确定。
那么我们不妨使用
点值表示法进行乘法
我们引用生物学上的一句话:“存在即合理”。这种点值表示法之所以被众多数学家采纳,它一定是有用的,其用处在于:乘法很好计算。
例如:我们有
假设
只要我们从
傅里叶变换的流程
傅里叶采用了这种表示法,设计出了傅里叶变换,其主要流程非常简单自然:
- 将
F(x) 和G(x) 转成点值表示法。 - 求
H(x) 的点值表示法。 - 根据
H(x) 的点值表示法,求H(x) 的系数表示法。
显然步骤 2 比较简单,那么重点在于步骤 1 和 3,它们分别称作:
- 离散傅里叶变换(Discrete Fourier transform,DFT)
- 逆离散傅里叶变换(Inverse discrete Fourier transform,IDFT)
我们的神人为了高效执行这两个步骤,他决定:
- 将复数单位根
\omega_{n+1}^0\sim\omega_{n+1}^n 代入多项式来求点值表示法。
接下来我们来讲解这个神秘方法的原理。但是在这之前,我们先需要学习复数的一些基本知识。
复数
首先我们定义
我们知道,实数
那么这里复数
复数相关的记号:
- 实部:
a 。 - 虚部:
b 。 - 共轭:实部相同,虚部取相反数。
- 模长:极坐标表示法中的
r ,也就是\sqrt{a^2+b^2} 。 - 幅角:极坐标表示法中的
\theta ,也就是\arctan\frac ba 。
复数相关的基本运算:
-
加法:
(a+b\cdot i)+(c+d\cdot i)=(a+c)+(b+d)\cdot i 。实部、虚部分别相加。 -
减法:
(a+b\cdot i)-(c+d\cdot i)=(a-c)+(b-d)\cdot i 。实部、虚部分别相减。 -
乘法:
(a+b\cdot i)\cdot(c+d\cdot i)=(a\cdot c-b\cdot d)+(a\cdot d+b\cdot c)\cdot i 。直接按整式运算法则展开即可。 -
除法:
\frac{a+b\cdot i}{c+d\cdot i}=\frac{a\cdot c+b\cdot d}{c^2+d^2}+\frac{a\cdot d+b\cdot c}{c^2+d^2}\cdot i 。类比“分母有理化”。
复数乘法的几何意义
我们不妨采用更加贴近几何的类似于极坐标的表示法:
解释:
- 极坐标表示法。
- 直接乘,暴力展开,提出
r_1r_2 。 - 三角函数的和差公式。
你发现这得出了一个新的复数:
- 其模长为原先两个复数的模长的乘积。
- 其幅角为原先两个复数的幅角的和。
这一结论非常的重要。无论是数学还是 OI。
单位根
定义:若复数
首先
那么可能单位根的复数模长一定为
由于模长肯定为
我们发现:幅角为
这些复数共有
几何意义上来说,这些单位根把单位圆划分成
单位根相关性质
单位根有一些相当重要的性质:
回归算法本身
现在我们已经学习了所有需要学的内容,下面我们来实现离散傅里叶变换和逆离散傅里叶变换吧。
离散傅里叶变换
假设有
-
F_1(x)=\sum_{i=0}^{\frac n2-1}f_{2i}x^i=f_0x^0+f_2x^1+f_4x^2+\cdots+f_{n-2}x^{\frac 2n-1} -
那么就有
首先代入
解释:
- 直接代入。
- 根据单位根性质 2。
- 根据单位根性质 4。
这个式子有什么用呢?假设我们有
但是我们这时候还缺
解释:
- 直接代入。
- 根据单位根性质 2。
- 单位根可以减去整圈数。
- 根据单位根性质 4。
- 根据单位根性质 5。
再运用这个式子,你就又
那么我们就实现了根据
时间复杂度分析和归并排序类似,为
逆离散傅里叶变换
假设我们对
令
首先展开
解释:
然后将结论代入:
解释:
- 直接代入。
- 交换求和顺序。
分类讨论:
解释:
- 代入。
-
- 显然。
解释:
- 代入。
-
-
- 根据定义,
\omega_n^i 构成等比数列,公比为\omega_n^1 。使用等比数列求和公式可得。 - 显然。
两部分的贡献加起来显然为
然后你发现这个和离散傅里叶变换的式子很像,区别仅有:
- 最后的答案要除以
n 。 - 将
\omega_n^1 替换为\omega_n^{-1} 。
那么我们逆离散傅里叶变换的部分也讲完了。
快速傅里叶变换的流程总结
我们已经知道了两个关键步骤的实现,下面总结一下流程:
- 将两个多项式
F(x) 和G(x) 补齐到n-1 次,其中n=2^k\space(k\in\N^*) 。 - 分别对它们进行离散傅里叶变换,得到点值
F'(k) 和G'(k) 。 - 将点值函数对位相乘,得到新的点值
H'(k) 。 - 对
H'(k) 进行逆离散傅里叶变换,得到多项式H(x) 的系数表示。 - 将
H(x) 缩减到其应有的长度。
高精度乘法优化
如果你按照以上流程写了一份快速傅里叶变换实现的高精度乘法,你会发现要么有些点超时了,要么就是过了但跑的巨慢(6 秒左右,甚至会更慢)。
同时,如果你尝试过 Python 的话,你发现用 Python 写的高精度乘法只用了 1 秒多。况且翻翻 Python 底层代码,发现它用的是
这样大的常数我们是无法忍受的,下面我们对它进行效果显著的优化。
分治转倍增
我们发现在递归分治的过程中,发生了一些很影响效率的事情:
- 分治时,我们将奇数项系数与偶数项系数分开,使用了大量的内存移动。
- 合并时,我们也使用了临时的内存附注计算,使用了大量的内存移动。
- 递归本身带来了额外的额外开销。
那么我们把递归分治转成迭代倍增,就可以解决以上所有问题。
无法直接转化为迭代实现的一个重要原因就是:你不知道一个元素在分治结束后的位置。那么我们为了转成迭代,我们不妨打表,然后发现这样一条规律:一个元素分治后的位置的二进制表示,是其分治前的位置的二进制表示的位逆序置换。
用人话说,就是把原先位置的二进制倒过来就可以得到新的位置了。下面我们对这一结论进行简单的证明。
令
解释:
- 只有一项时不分治,显然位置不动。
- 将奇数项
[k\otimes1=1] 移动到左边,偶数项放到右边。
将其展开,得到:
解释这个式子,就是:将
有了这条结论,我们就可以:
- 在开始时将每个元素移动到其分治后的位置上(模拟分治过程)。
- 然后,倍增地去合并它们。
同时你发现,在使用位逆序置换后,当你想要合并,也就是通过
于是我们避免了原先递归那样额外开数组然后再复制的操作,直接原位合并,省下了空间和时间。那么现在的关键就在于:如何求位逆序置换。
首先我们假设
解释:
于是我们从小到大计算
压位
这是针对高精度计算的经典优化。
将连续的 std::uint32_t (unsigned int) 进行存储。
预处理单位根
这一点是显然的。你单次计算单位根需要调用三角函数,效率非常低。不如我们直接在最开始就把所有单位根算出来。
首先一种朴素的想法是:对于每一个单位根的计算,我们都调用三角函数。这种方法效率是极低的,虽然比之前的方法好一些。
另一种朴素的想法是:先用三角函数计算
我们采取一种折中的方法:对于所有
内存访问优化
普通的逆序位置换方法设计了大量的内存随机访问,会增加一定的常数。稍微重新排布一下置换以及合并的顺序,确保内存顺序访问,可以保证缓存友好,加快一定的运行速度。
比如我的代码使用了一种隐式的逆序位置换的策略,甚至还避免了重排。
同时,你可以使用迭代器以及指针来代替朴素的数组下标访问以获得更小的常数。
更高效的复数类
你发现 std::complex 似乎速度并不快,它底层的乘法实现仍然是四次乘法。
不妨我们自己写一个基于 AVX2 指令集 __m128d 的复数类:
struct Complex {
__m128d value;
Complex() = default;
Complex(__m128d initialValue) : value(initialValue) {}
Complex(double real, double imagine = 0.0) : value(_mm_set_pd(imagine, real)) {}
Complex operator+(Complex other) const {
return value + other.value;
}
Complex operator-(Complex other) const {
return value - other.value;
}
Complex operator*(Complex other) const {
return _mm_fmaddsub_pd(_mm_unpacklo_pd(value, value), other.value, _mm_unpackhi_pd(value, value) * _mm_permute_pd(other.value, 1));
}
Complex operator*(double other) const {
return value * _mm_set1_pd(other);
}
double real() {
return value[0];
}
double imag() {
return value[1];
}
};
代码
完整代码见提交记录,这里只放上快速傅里叶变换相关的部分。
值得一提的是,这个模板居然跑的比那些大力循环展开(放眼望去,满屏幕的 switch case)的代码跑的快了大约 120ms。
#include <bits/stdc++.h>
struct Transform {
static constexpr std::int64_t Base = 100000000, FFTBase = 10000;
static constexpr std::double_t Pi = 3.141592653589793;
static std::vector<std::complex<std::double_t>> omega;
static void initialize(std::size_t length) {
if (length > omega.size() << 1) {
std::size_t temp = std::__lg(length - 1);
omega.resize(length = 1 << temp), omega[0] = 1.0;
for (std::size_t i = 1; i < length; i <<= 1)
omega[i] = std::polar(1.0, Pi / (i << 1));
for (std::size_t i = 1; i < length; ++i)
omega[i] = omega[i & (i - 1)] * omega[i & -i];
}
}
static void DFT(std::vector<std::complex<std::double_t>>& array) {
for (std::size_t left = array.size() >> 1, right = array.size(); left; left >>= 1, right >>= 1) {
for (std::size_t i = 0; i < left; ++i) {
const std::complex<std::double_t> A = array[i], B = array[i + left];
array[i] = A + B, array[i + left] = A - B;
}
auto index = omega.begin() + 1;
for (std::size_t i = right; i < array.size(); i += right, ++index) {
for (std::size_t j = i; j < i + left; ++j) {
const std::complex<std::double_t> A = array[j], B = array[j + left] * *index;
array[j] = A + B, array[j + left] = A - B;
}
}
}
}
static void IDFT(std::vector<std::complex<std::double_t>>& array) {
for (std::size_t left = 1, right = 2; left < array.size(); left <<= 1, right <<= 1) {
for (std::size_t i = 0; i < left; ++i) {
const std::complex<std::double_t> A = array[i], B = array[i + left];
array[i] = A + B, array[i + left] = A - B;
}
auto index = omega.begin() + 1;
for (std::size_t i = right; i < array.size(); i += right, ++index) {
for (std::size_t j = i; j < i + left; ++j) {
const std::complex<std::double_t> A = array[j], B = array[j + left];
array[j] = A + B, array[j + left] = (A - B) * std::conj(*index);
}
}
}
}
static void convolution(std::vector<std::complex<std::double_t>>& A, std::vector<std::complex<std::double_t>>& B) {
Transform::initialize(A.size()), Transform:: DFT(A), Transform:: DFT(B);
const std::double_t inverse = 1.0 / A.size(), quarterInverse = 0.25 * inverse;
A.front() = std::complex<std::double_t>(
A.front().real() * B.front().real() + A.front().imag() * B.front().imag(),
A.front().real() * B.front().imag() + A.front().imag() * B.front().real()
), A.front() *= inverse, A[1] *= B[1] * inverse;
for (std::size_t k = 2, m = 3; k < A.size(); k <<= 1, m <<= 1) {
for (std::size_t i = k, j = i + k - 1; i < m; ++i, --j) {
const std::complex<std::double_t> oi = A[i] + std::conj(A[j]), hi = A[i] - std::conj(A[j]);
const std::complex<std::double_t> Oi = B[i] + std::conj(B[j]), Hi = B[i] - std::conj(B[j]);
const std::complex<std::double_t> r0 = oi * Oi - hi * Hi * ((i & 1) ? -omega[i >> 1] : omega[i >> 1]), r1 = Oi * hi + oi * Hi;
A[i] = (r0 + r1) * quarterInverse, A[j] = std::conj(r0 - r1) * quarterInverse;
}
}
Transform:: IDFT(A);
}
static void multiply(std::vector<std::int64_t>& A, const std::vector<std::int64_t>& B) {
const std::size_t sizeA = A.size(), sizeB = B.size(), size = 2 << std::__lg(sizeA + sizeB - 1);
std::vector<std::complex<std::double_t>> cA(size, 0.0), cB(size, 0.0);
std::transform(A.begin(), A.end(), cA.begin(), [](std::int64_t element) {
return std::complex<std::double_t>(element % FFTBase, element / FFTBase);
});
std::transform(B.begin(), B.end(), cB.begin(), [](std::int64_t element) {
return std::complex<std::double_t>(element % FFTBase, element / FFTBase);
});
std::int64_t sum = 0;
convolution(cA, cB), A.resize(sizeA + sizeB);
std::transform(cA.begin(), cA.begin() + A.size(), A.begin(), [&](std::complex<std::double_t> element) {
sum += std::int64_t(element.real() + 0.5) + std::int64_t(element.imag() + 0.5) * FFTBase;
std::int64_t result = sum % Base;
return sum /= Base, result;
});
}
};