题解:P1919 【模板】高精度乘法 | A*B Problem 升级版

· · 题解

前言

本篇题解主要介绍了快速傅里叶变换以及其优化技巧。

作者目前为本题最优解第二名。

快速傅里叶变换详解

快速傅里叶变换(Fast Fourier transform,FFT)是用于计算多项式乘法的重要算法,也是其它多项式算法的根本。

下面我将尽可能详细地讲解它。

前置知识

非常基础。

从整数乘法到多项式乘法

首先使用小学学习竖式计算的时候就已经会了的拆位将整数拆开,变成一个多项式的表示形式。

举个例子:

114514=1\cdot10^5+1\cdot10^4+4\cdot10^3+5\cdot10^2+1\cdot10^1+4\cdot10^0

我们把拆开的式子中的 10 看作 x,得到 114514 所对应的多项式 F(x)

F(x)=1\cdot x^5+1\cdot x^4+4\cdot x^3+5\cdot x^2+1\cdot x^1+4\cdot x^0

于是你完成了从整数乘法到多项式乘法的转换。非常简单,不是吗。

多项式乘法的一些记号

假设我们有多项式 F(x)=\sum_{i=0}^nf_ix^i,那么:

别样的多项式表示形式

初中学平面直角坐标系时,我们已经充分掌握了用两个点确定一条直线解析式的能力,使用待定系数法列出方程即可。

我们又知道:n 个等式可以得出 n 元方程的解。那么我们得出结论:n+1 个点可以确定一个 n 次多项式

例如直线就是一个一次多项式,它就可以用两个点确定。

那么我们不妨使用 n+1 个点来表示多项式,这些点你可以随便选。这就是多项式的点值表示法

点值表示法进行乘法

我们引用生物学上的一句话:“存在即合理”。这种点值表示法之所以被众多数学家采纳,它一定是有用的,其用处在于:乘法很好计算。

例如:我们有 n 次多项式 F(x)m 次多项式 G(x) 的点值表示法,求 n+m 次多项式 H(x)=F(x)\cdot G(x) 的点值表示法。

假设 F(x) 上有 (a,b)G(x) 上有 (a,c),那么 H(x) 上就有 (a,H(a))=(a,F(a)\cdot G(a))=(a,b\cdot c)。我们这样就获得了 H(x) 上的一个点。

只要我们从 F(x)G(x) 上各找 n+m+1 个点,我们就能求出 H(x) 的点值表示法了,非常的方便。

傅里叶变换的流程

傅里叶采用了这种表示法,设计出了傅里叶变换,其主要流程非常简单自然:

  1. F(x)G(x) 转成点值表示法。
  2. H(x) 的点值表示法。
  3. 根据 H(x) 的点值表示法,求 H(x) 的系数表示法。

显然步骤 2 比较简单,那么重点在于步骤 1 和 3,它们分别称作:

我们的神人为了高效执行这两个步骤,他决定:

接下来我们来讲解这个神秘方法的原理。但是在这之前,我们先需要学习复数的一些基本知识。

复数

首先我们定义 i^2=-1。复数是形如 a+b\cdot i\space (a,b\in\R) 的数。

我们知道,实数 a 可以用实数轴上的一个点进行表示。现在复数多了 b\cdot i 这一项,我们不妨拉出一条纵向的虚数轴,组成一个复平面直角坐标系

那么这里复数 a+b\cdot i 就可以与点 (a,b) 对应了。当然,你也可以使用类似于极坐标系的表示方法,把复数表示成 r(\cos\theta+i\sin\theta) 的形式,其中 r 是这个点到原点的距离,\theta 是这个点到原点组成的直线与实数轴形成的夹角。

复数相关的记号:

复数相关的基本运算:

复数乘法的几何意义

我们不妨采用更加贴近几何的类似于极坐标的表示法:

\begin{align} r_1(\cos\theta_1+i\sin\theta_1)\cdot r_2(\cos\theta_2+i\sin\theta_2)\\ =r_1r_2(\cos\theta_1\cos\theta_2-\sin\theta_1\sin\theta_2+i(\sin\theta_1\cos\theta_2+\cos\theta_1\sin\theta_2))\\ =r_1r_2(\cos(\theta_1+\theta_2)+i\sin(\theta_1+\theta_2)) \end{align}

解释:

  1. 极坐标表示法。
  2. 直接乘,暴力展开,提出 r_1r_2
  3. 三角函数的和差公式。

你发现这得出了一个新的复数:

这一结论非常的重要。无论是数学还是 OI。

单位根

定义:若复数 c 满足 c^n=1\space(n\in\N^*),则称 c 为一个 n 次单位根。如何求解呢?

首先 1 可以表示为 1+0\cdot i,也就是一个模长为 1 的复数。根据前一节提到的复数相乘几何意义,若 c 的模长大于 1,那模长肯定越乘越长,不可能达到 1;若 c 的模长小于 1,那模长肯定越乘越小,也不可能达到 1

那么可能单位根的复数模长一定为 1。根据几何意义,我们发现:这是复平面直角坐标系上,以原点为圆心,半径为 1 的圆。学过高中三角函数的知道,这是单位圆。

由于模长肯定为 1,我们可以抛开不看,我们只关心幅角。不妨假设 c 的幅角为 \theta,那么根据复数相乘的几何意义,c^n 的幅角为 n\cdot\theta,若它是单位根,那这个角度一定和 1+0\cdot i 的角度重合(也就是和实数轴重合)。

我们发现:幅角为 \frac kn\cdot2\pi\space(k\in\N\cap[0,n))(这里 2\pi 是弧度制表示法,没学过的话把它当 360\degree 看)的复数都是单位根,因为它们的幅角自乘 n 倍后为 k\cdot 2\pi,而 k 是整数,所以说明它旋转了周角的整数倍,也就是转回来了。

这些复数共有 n 个。根据代数基本定理: n 次方程在复数域内有且只有 n 个根,所以这些就是所有的单位根。

几何意义上来说,这些单位根把单位圆划分成 n 个相同的扇形。我们从复数 1 开始,逆时针将这些复数标号 0\sim n-1,分别记作 \omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}

单位根相关性质

单位根有一些相当重要的性质:

回归算法本身

现在我们已经学习了所有需要学的内容,下面我们来实现离散傅里叶变换和逆离散傅里叶变换吧。

离散傅里叶变换

假设有 n=2^k\space(k\in\N^*),对于一个 n-1 次多项式 F(x)=\sum_{i=0}^{n-1}f_ix^i,我们将其按次数的奇偶拆成两个部分:

那么就有 F(x)=F_1(x^2)+x\cdot F_2(x)。下面是离散傅里叶变换的精髓,代入单位根:

首先代入 \omega_n^k 得到:

\begin{align} F(\omega_n^k)=F_1((\omega_n^k)^2)+\omega_n^k\cdot F_2((\omega_n^k)^2)\\ =F_1(\omega_n^{2k})+\omega_n^k\cdot F_2(\omega_n^{2k})\\ =F_1(\omega_{\frac n2}^k)+\omega_n^k\cdot F_2(\omega_{\frac n2}^k) \end{align}

解释:

  1. 直接代入。
  2. 根据单位根性质 2。
  3. 根据单位根性质 4。

这个式子有什么用呢?假设我们有 F_1(x)F_2(x)\omega_{\frac n2}^{0\sim\frac n2-1} 的点值,我们就可以 O(n) 求出 F(x)\omega_n^{0\sim\frac n2-1} 的点值

但是我们这时候还缺 \frac n2 个点,不过我们还可以代入 \omega_n^{k+\frac n2}

\begin{align} F(\omega_n^{k+\frac n2})=F_1((\omega_n^{k+\frac n2})^2)+\omega_n^{k+\frac n2}\cdot F_2((\omega_n^{k+\frac n2})^2)\\ =F_1(\omega_n^{2k+n})+\omega_n^{k+\frac n2}\cdot F_2(\omega_n^{2k+n})\\ =F_1(\omega_n^{2k})+\omega_n^{k+\frac n2}\cdot F_2(\omega_n^{2k})\\ =F_1(\omega_{\frac n2}^k)+\omega_n^{k+\frac n2}\cdot F_2(\omega_{\frac n2}^k)\\ =F_1(\omega_{\frac n2}^k)-\omega_n^k\cdot F_2(\omega_{\frac n2}^k) \end{align}

解释:

  1. 直接代入。
  2. 根据单位根性质 2。
  3. 单位根可以减去整圈数。
  4. 根据单位根性质 4。
  5. 根据单位根性质 5。

再运用这个式子,你就又 O(n) 求出了 F(x)\omega_n^{\frac n2\sim n-1} 的点值。再看看上一个式子,我们就有了 n 个点值。

那么我们就实现了根据 F_1(x)F_2(x) 的点值 O(n) 计算 F(x)。然后我们递归利用这个方法解决 F_1(x)F_2(x) 的求值即可。

时间复杂度分析和归并排序类似,为 O(n\log n),我们实现了离散傅里叶变换!

逆离散傅里叶变换

假设我们对 F(x) 执行离散傅里叶变换,得到了 F(x)\omega_n^{0\sim n-1} 上的点值,现在我们要求 F(x) 的系数表达式,也就是求 f

G(k)F(x)\omega_n^k 上的点值,即 G(k)=F(\omega_n^k)。先说结论:f_k=\frac1n\sum_{i=0}^{n-1}(\omega_n^{-k})^iG(i),我们进行证明。

首先展开 G(k) 的定义:

\begin{align} G(k)=F(\omega_n^k)\\ =\sum_{i=0}^{n-1}(\omega_n^k)^if_i \end{align}

解释:

然后将结论代入:

\begin{align} =\frac1n\sum_{i=0}^{n-1}(\omega_n^k)^i\sum_{j=0}^{n-1}(\omega_n^{-i})^jG(j)\\ =\frac1n\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}\omega_n^{ik-ij}G(j) \end{align}

解释:

  1. 直接代入。
  2. 交换求和顺序。

分类讨论:

\begin{align} \frac1n\sum_{i=0}^{n-1}\omega_n^0G(k)\\ =\frac1n\cdot n\cdot G(k)\\ =G(k) \end{align}

解释:

  1. 代入。
  2. 显然。
\begin{align} \frac1n\sum_{i=0}^{n-1}\omega_n^{i\cdot t}G(j)\\ =\frac1n\sum_{i=0}^{n-1}\omega_n^{i\cdot t}G(k-t)\\ =\frac1n\omega_i^t\bigg(\sum_{i=0}^{n-1}\omega_n^i\bigg)G(k-t)\\ =\frac1n\omega_i^t\cdot0\cdot G(k-t)\\ =0 \end{align}

解释:

  1. 代入。
  2. 根据定义,\omega_n^i 构成等比数列,公比为 \omega_n^1。使用等比数列求和公式可得。
  3. 显然。

两部分的贡献加起来显然为 G(k),这证明了我们的结论是正确的。

然后你发现这个和离散傅里叶变换的式子很像,区别仅有:

那么我们逆离散傅里叶变换的部分也讲完了。

快速傅里叶变换的流程总结

我们已经知道了两个关键步骤的实现,下面总结一下流程:

  1. 将两个多项式 F(x)G(x) 补齐到 n-1 次,其中 n=2^k\space(k\in\N^*)
  2. 分别对它们进行离散傅里叶变换,得到点值 F'(k)G'(k)
  3. 将点值函数对位相乘,得到新的点值 H'(k)
  4. H'(k) 进行逆离散傅里叶变换,得到多项式 H(x) 的系数表示。
  5. H(x) 缩减到其应有的长度。

高精度乘法优化

如果你按照以上流程写了一份快速傅里叶变换实现的高精度乘法,你会发现要么有些点超时了,要么就是过了但跑的巨慢(6 秒左右,甚至会更慢)。

同时,如果你尝试过 Python 的话,你发现用 Python 写的高精度乘法只用了 1 秒多。况且翻翻 Python 底层代码,发现它用的是 O(n^{\log_23})\approx O(n^{1.58}) 的 Karatsuba 分治乘法,复杂度甚至劣于快速傅里叶变换,但跑得却比我们快得多。

这样大的常数我们是无法忍受的,下面我们对它进行效果显著的优化。

分治转倍增

我们发现在递归分治的过程中,发生了一些很影响效率的事情:

那么我们把递归分治转成迭代倍增,就可以解决以上所有问题。

无法直接转化为迭代实现的一个重要原因就是:你不知道一个元素在分治结束后的位置。那么我们为了转成迭代,我们不妨打表,然后发现这样一条规律:一个元素分治后的位置的二进制表示,是其分治前的位置的二进制表示的位逆序置换

用人话说,就是把原先位置的二进制倒过来就可以得到新的位置了。下面我们对这一结论进行简单的证明。

R(n,k) 表示一个 n 项多项式中原先位置为 k 的元素分治后的位置。根据定义,有:

\begin{cases} R(1,k)=k\\ R(n,k)=(k\otimes1)\texttt{ << }(n-1)+R(n-1,k\texttt{ >> }1) \end{cases}

解释:

  1. 只有一项时不分治,显然位置不动。
  2. 将奇数项 [k\otimes1=1] 移动到左边,偶数项放到右边。

将其展开,得到:

R(n,k)=\sum_{i=0}^{n-1}((k\texttt{ >> }i)\otimes1)\texttt{ << }(n-i-1)

解释这个式子,就是:将 k 的第 i 位移动到第 n-i-1 位。这正是“位逆序置换”。

有了这条结论,我们就可以:

同时你发现,在使用位逆序置换后,当你想要合并,也就是通过 F_1(\omega_n^k) 以及 F_2(\omega_n^{k+\frac n2}) 时求 F(\omega_n^k)F(\omega_n^{k+\frac n2}) 时,你发现:

于是我们避免了原先递归那样额外开数组然后再复制的操作,直接原位合并,省下了空间和时间。那么现在的关键就在于:如何求位逆序置换。

首先我们假设 P(x)x 的位逆序置换,那么有 P(x)=P(x\texttt{ >> }1)\texttt{ >> }1+(x\otimes1)\texttt{ << }(n-1)

解释:

于是我们从小到大计算 P(x) 即可。以上是快速傅里叶变换经典优化方法,加上这些优化完全足够了。

压位

这是针对高精度计算的经典优化。

将连续的 w 个数位压缩至一个数进行存储,可以有效提升效率。一般取 w=8,用 32 位无符号整数 std::uint32_t (unsigned int) 进行存储。

预处理单位根

这一点是显然的。你单次计算单位根需要调用三角函数,效率非常低。不如我们直接在最开始就把所有单位根算出来。

首先一种朴素的想法是:对于每一个单位根的计算,我们都调用三角函数。这种方法效率是极低的,虽然比之前的方法好一些。

另一种朴素的想法是:先用三角函数计算 \omega_n^1,然后根据 \omega_n^i\cdot\omega_n^j=\omega_n^{i+j} 依次用 \omega_n^1 递推转移。这种方法效率是高了,但是精度没有保障。

我们采取一种折中的方法:对于所有 i=2^k\space(k\in\N^*)\omega_n^i,我们调用三角函数进行计算;剩下的,我们有 \omega_n^i=\omega_n^{\operatorname{lowbit}(i)}\cdot\omega_n^{i-\operatorname{lowbit}(i)}。这种方法只用 \log n 次三角函数调用,同时每一个 \omega_n^i 都至多在 \log n 次递推内算出。

内存访问优化

普通的逆序位置换方法设计了大量的内存随机访问,会增加一定的常数。稍微重新排布一下置换以及合并的顺序,确保内存顺序访问,可以保证缓存友好,加快一定的运行速度。

比如我的代码使用了一种隐式的逆序位置换的策略,甚至还避免了重排。

同时,你可以使用迭代器以及指针来代替朴素的数组下标访问以获得更小的常数。

更高效的复数类

你发现 std::complex 似乎速度并不快,它底层的乘法实现仍然是四次乘法。

不妨我们自己写一个基于 AVX2 指令集 __m128d 的复数类:

struct Complex {
    __m128d value;

    Complex() = default;

    Complex(__m128d initialValue) : value(initialValue) {}

    Complex(double real, double imagine = 0.0) : value(_mm_set_pd(imagine, real)) {}

    Complex operator+(Complex other) const {
        return value + other.value;
    }

    Complex operator-(Complex other) const {
        return value - other.value;
    }

    Complex operator*(Complex other) const {
        return _mm_fmaddsub_pd(_mm_unpacklo_pd(value, value), other.value, _mm_unpackhi_pd(value, value) * _mm_permute_pd(other.value, 1));
    }

    Complex operator*(double other) const {
        return value * _mm_set1_pd(other);
    }

    double real() {
        return value[0];
    }

    double imag() {
        return value[1];
    }
};

代码

完整代码见提交记录,这里只放上快速傅里叶变换相关的部分。

值得一提的是,这个模板居然跑的比那些大力循环展开(放眼望去,满屏幕的 switch case)的代码跑的快了大约 120ms。

#include <bits/stdc++.h>

struct Transform {
    static constexpr std::int64_t Base = 100000000, FFTBase = 10000;
    static constexpr std::double_t Pi = 3.141592653589793;

    static std::vector<std::complex<std::double_t>> omega;

    static void initialize(std::size_t length) {
        if (length > omega.size() << 1) {
            std::size_t temp = std::__lg(length - 1);
            omega.resize(length = 1 << temp), omega[0] = 1.0;
            for (std::size_t i = 1; i < length; i <<= 1)
                omega[i] = std::polar(1.0, Pi / (i << 1));
            for (std::size_t i = 1; i < length; ++i)
                omega[i] = omega[i & (i - 1)] * omega[i & -i];
        }
    }

    static void  DFT(std::vector<std::complex<std::double_t>>& array) {
        for (std::size_t left = array.size() >> 1, right = array.size(); left; left >>= 1, right >>= 1) {
            for (std::size_t i = 0; i < left; ++i) {
                const std::complex<std::double_t> A = array[i], B = array[i + left];
                array[i] = A + B, array[i + left] = A - B;
            }
            auto index = omega.begin() + 1;
            for (std::size_t i = right; i < array.size(); i += right, ++index) {
                for (std::size_t j = i; j < i + left; ++j) {
                    const std::complex<std::double_t> A = array[j], B = array[j + left] * *index;
                    array[j] = A + B, array[j + left] = A - B;
                }
            }
        }
    }

    static void  IDFT(std::vector<std::complex<std::double_t>>& array) {
        for (std::size_t left = 1, right = 2; left < array.size(); left <<= 1, right <<= 1) {
            for (std::size_t i = 0; i < left; ++i) {
                const std::complex<std::double_t> A = array[i], B = array[i + left];
                array[i] = A + B, array[i + left] = A - B;
            }
            auto index = omega.begin() + 1;
            for (std::size_t i = right; i < array.size(); i += right, ++index) {
                for (std::size_t j = i; j < i + left; ++j) {
                    const std::complex<std::double_t> A = array[j], B = array[j + left];
                    array[j] = A + B, array[j + left] = (A - B) * std::conj(*index);
                }
            }
        }
    }

    static void convolution(std::vector<std::complex<std::double_t>>& A, std::vector<std::complex<std::double_t>>& B) {
        Transform::initialize(A.size()), Transform:: DFT(A), Transform:: DFT(B);
        const std::double_t inverse = 1.0 / A.size(), quarterInverse = 0.25 * inverse;
        A.front() = std::complex<std::double_t>(
            A.front().real() * B.front().real() + A.front().imag() * B.front().imag(),
            A.front().real() * B.front().imag() + A.front().imag() * B.front().real()
        ), A.front() *= inverse, A[1] *= B[1] * inverse;

        for (std::size_t k = 2, m = 3; k < A.size(); k <<= 1, m <<= 1) {
            for (std::size_t i = k, j = i + k - 1; i < m; ++i, --j) {
                const std::complex<std::double_t> oi = A[i] + std::conj(A[j]), hi = A[i] - std::conj(A[j]);
                const std::complex<std::double_t> Oi = B[i] + std::conj(B[j]), Hi = B[i] - std::conj(B[j]);
                const std::complex<std::double_t> r0 = oi * Oi - hi * Hi * ((i & 1) ? -omega[i >> 1] : omega[i >> 1]), r1 = Oi * hi + oi * Hi;
                A[i] = (r0 + r1) * quarterInverse, A[j] = std::conj(r0 - r1) * quarterInverse;
            }
        }
        Transform:: IDFT(A);
    }

    static void multiply(std::vector<std::int64_t>& A, const std::vector<std::int64_t>& B) {
        const std::size_t sizeA = A.size(), sizeB = B.size(), size = 2 << std::__lg(sizeA + sizeB - 1);
        std::vector<std::complex<std::double_t>> cA(size, 0.0), cB(size, 0.0);
        std::transform(A.begin(), A.end(), cA.begin(), [](std::int64_t element) {
            return std::complex<std::double_t>(element % FFTBase, element / FFTBase);
        });
        std::transform(B.begin(), B.end(), cB.begin(), [](std::int64_t element) {
            return std::complex<std::double_t>(element % FFTBase, element / FFTBase);
        });
        std::int64_t sum = 0;
        convolution(cA, cB), A.resize(sizeA + sizeB);
        std::transform(cA.begin(), cA.begin() + A.size(), A.begin(), [&](std::complex<std::double_t> element) {
            sum += std::int64_t(element.real() + 0.5) + std::int64_t(element.imag() + 0.5) * FFTBase;
            std::int64_t result = sum % Base;
            return sum /= Base, result;
        });
    }
};