P1919 【模板】高精度乘法 | A*B Problem 升级版 - Solution

· · 题解

挂一个人,图文无关。

感觉到目前为止题解区缺少对 FFT 的具体描述,或者要么很不适合初学者进行学习。因此我决定写一篇足够适合初学者学习的题解,因此会适当牺牲一些严谨性和相对困难的证明。

个人认为本文如果你不在乎证明且其余部分较认真的读的话,即使只有初中三年级数学水平也足够看懂。

希望如果有一些疏忽导致的事实性错误可以有人指出,本文可能会视情况修正一些错误以及更新一些扩展。

强烈建议你在知道 \sin\cos 是什么之后再阅读本文,强烈建议你知道弧度值是什么之后再阅读本文。如果你实在不知道什么是弧度值,你就记住 \pi = 180^\circ 即可。

建议你在知道向量是什么之后再阅读本文。

如果可以,建议先知道复数是什么。

注意,如果你要看懂证明的话,至少要有部分线性代数和复数的知识背景。

虽然不是必须的,但最好有一部分线性代数基础。

我们初中的时候就学过了带入系数法求二次函数系数。当然可以扩展到多项式,设我们的多项式是 \sum\limits_{i = 0}^{n} a_ix^i

考虑 n 个点的点值 (x_i,\,y_i),\,0 \le i \le n,\,i \ne j \Rightarrow x_i \ne x_j,以及线性方程组:

\begin{bmatrix} 1 & x_0 & x_0^2 & \cdots & x_0^{n} \\ 1 & x_1 & x_1^2 & \cdots & x_1^{n} \\ 1 & x_2 & x_2^2 & \cdots & x_2^{n} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & x_n & x_n^2 & \cdots & x_n^{n} \\ \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_n \end{bmatrix} = \begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ \vdots \\ y_n \end{bmatrix}

这个是范德蒙矩阵,可以用归纳法证明其满秩,且行列式值为 \prod\limits_{0 \le j < i \le n} x_i - x_j

所以原线性方程组有唯一解。

那么我们也就可以知道它们乘积的 n + m 个点值 (x_i,\,y_i \times y'_i)

我们知道 n + m 个点值确定一个 n + m - 1 次多项式,而 n 次多项式和 m 次多项式相乘恰好是 n + m - 1 次多项式。所以这 n + m 个点值恰好能够确定我们的乘积(一个 n + m - 1 次多项式)。当然我们的多项式可能是补上来的,次数没有 n + m - 1 次,但是这个并不影响,我们的点值也能唯一确定这个多项式(比如用三点共线确定一条直线)。

由此确定了我们的思路:求点值,对应相乘,再插值。

比如我们完全可以求 (0,\,f(0)),求 (1,\,f(1))

但不管怎么在实数域上选取横坐标,我们的点值仍然不好求。

复数

我们定义 i = \sqrt{-1},这看起来很魔怔,不过确实能帮助我们解决相当多问题,比如接下来的多项式乘法。

我们把形如 z = a + bi 的数称为复数,同时令 \mathbb C 为复数域。那么我们可以定义复数的运算了,z_0 = a + bi,\,z_1 = c + di,\,z_0 + z_1 = (a + c) + (b + d)i,\,z_0 - z_1 = (a + c) - (b + d)i,\,z_0 \times z_1 = (a + bi)(c + di) = (ac - bd) + (ad + bc)i,\,\dfrac{z_0}{z_1} = \dfrac{ac + bd}{c^2 + d^2} + \dfrac{bc - ad}{c^2 + d^2}i

对于复数 z = a + bi,我们将 a 称为 z实部b 称为 z虚部。

Euler Formulae^{\theta i} = \cos(\theta) + i\sin(\theta)

关于欧拉公式,严谨证明涉及到太多东西。这里给个导数方面的理解方法(只能说是理解方法,因为实际上有循环论证)。

我们知道 (e^{xa})' = ae^{ax},那么 (e^{xi})' = ie^{xi} 可以理解为导数是自身的 i 倍,那么画个箭头其实就是每时每刻的移动方向都指向逆时针方向,并且垂直于当前位置代表的向量。这其实就是匀速圆周运动。

de Morive Formula:将复数看成向量,两个复数相乘的几何意义是辐角相加,模长相乘。对应的,两个复数相除的集合意义是辐角相减,模长相除。

向量 \vec z 的辐角:从 x 轴正半轴开始,逆时针转多少度能跟向量 \vec z 同向。

向量 \vec z 的模长就是向量 \vec z 的长度。

于是在复数域下,我们仍然有 n + 1 个点确定一个 n 次复系数多项式。毕竟显然只有 z = 0 的模长是 0,而我们之前 \prod\limits_{0 \le j < i \le n} x_i - x_j。由于 x_i \ne x_j 因此这些复数的模长也不会有一个是 0,得到的行列式结果也不会是 0(不过这里要说明满秩复系数方程组有唯一解,实际上复数不会影响我们证明实数时候的证明)。

Fast Fourier Transform

记得我们之前说过的吗?

我们的优势在于:我们可以任意选定要求的点值的横坐标。

n + 1 个点确定一个 n 次复系数多项式。

因此我们完全可以把我们的横坐标设置为复数,再进行求值和插值,也许就不同了。

以下默认 n2 的非负整数次幂,如果不是那么类似之前说的高位补零即可,这是为了之后的算法可以运行。这对我们的算法只是常数上的影响。

以下默认我们是对某个多项式 f(x)n 个点值,且多项式的次数小于 n

我们发现:把 \omega^{k}_{n} 画成点,其实就是单位圆上的 n 等分点。

单位根的模长都是 1,这样我们就完全不需要关心模长相乘了,只需要关心辐角。

并且通过欧拉公式,我们得到单位根的三个性质:

离散傅里叶变换:已知多项式 f(x),对 0 \le k < n,求(\omega^{k}_{n},\,f(\omega^{k}_{n})),时间复杂度 \Theta(n \log n)

离散傅里叶逆变换:已知 n 个点值 (\omega^{i}_{n},\,y_i)_{i = 0}^{n - 1},求其唯一对应的多项式 f(x),使得 f(x) 的次数为 n - 1f(\omega^{i}_{n}) = y_i。时间复杂度 \Theta(n \log n)

一个小 hint:在 OI 中说 FFT 的复杂度是 \Theta(n \log n) 是没有问题的,确实只进行 \Theta(n \log n) 次浮点数的加减乘除,但是实际上 FFT 的模至少要是 \Theta(\log n) 的,因此理论上我们使用的 SSA 算法是 \Theta(n \log n \log \log n) 的。不过人类已经发明了严格 \Theta(n \log n) 的多项式乘法的算法,参考 David Harvey 和 Joris van der Hoeven 的论文 Integer multiplication in time O(n log n),这实际上是很新的成果。

这东西其实是简单分治算法。

注意以下默认 n \ge 2,如果 n = 1 那我们直接返回就完了。

首先设 f(x) = \sum\limits_{i = 0}^{n - 1}a_ix^i 是我们要求点值的多项式。

现在我们要求 \omega_{n}^{k} 的点值对吧。FFT 会先把 f(x) 奇偶位拆开。

f(x) = a_0 + a_1x + a_2x^2 + \dots + a_{n - 1}x^{n - 1}\\ g(x) = a_0 + a_2x^2 + a_4x^4 + \dots + a_{n - 2}x^{n - 2}\\ h(x) = a_1 + a_3x^3 + a_5x^5 + \dots + a_{n - 1}x^{n - 1}\\ A(x) = a_0 + a_2x + a_4x^2 + \dots + a_{n - 2}x^{\frac{n}{2} - 1}\\ B(x) = a_1 + a_3x + a_5x^2 + \dots + a_{n - 1}x^{\frac{n}{2} - 1}\\ f(x) = A(x^2) + xB(x^2)

现在考虑 x = \omega^{k}_{n}

\begin{aligned} & f(\omega^{k}_{n}) \\ & = A(\omega^{2k}_{n}) + \omega^{k}_{n} \times B(\omega^{2k}_{n})\\ & = A(\omega^{k}_{\frac{n}{2}}) + \omega^{k}_{n} \times B(\omega^{k}_{\frac{n}{2}}) \end{aligned}

感觉规律还不是很明显?考虑一下我们喜闻乐见的 \omega^{k}_{n} = -\omega^{k + \frac{n}{2}}_{n}

\begin{aligned} & f(\omega^{k + \frac{n}{2}}_{n}) \\ & = A(\omega^{2k + n}_{n}) + \omega^{k + \frac{n}{2}}_{n} \times B(\omega^{2k + n}_{n})\\ & = A(\omega^{2k}_{n}) - \omega^{k}_{n} \times B(\omega^{2k}_{n})\\ & = A(\omega^{k}_{\frac{n}{2}}) - \omega^{k}_{n} \times B(\omega^{k}_{\frac{n}{2}}) \end{aligned}

于是我们只要求出来 A(\omega^{k}_{\frac{n}{2}})B(\omega^{k}_{\frac{n}{2}}) 就可以算出 f(\omega^{k}_{n})f(\omega^{k + \frac{n}{2}}_{n}) 了!

很有分治的感觉,现在我们可以设计一个函数 DFT,输入了一个多项式 f(x) 和其次数 + 1 的值 n。最后我们要返回 n 个点,对于第 k 个点横坐标为 \omega^{k}_{n},纵坐标为 f(\omega^{k}_{n})

我们把 A(\omega^{k}_{\frac{n}{2}})B(\omega^{k}_{\frac{n}{2}}) 的表达式写出来,注意这里 k < \frac{n}{2},因为我们可以靠着 k < \frac{n}{2} 这部分算出来 f(\omega^{k + \frac{n}{2}}_{n})

那么对于 A(x),其实就是已知多项式 A(x) 长度为 \frac{n}{2},求 k = 0 \dots \frac{n}{2} - 1,\,(\omega^{k}_{\frac{n}{2}},\,A(\omega^{k}_{\frac{n}{2}}))。对于 B(x),也是已知多项式 B(x) 长度为 \frac{n}{2},求 k = 0 \dots \frac{n}{2} - 1,\,(\omega^{k}_{\frac{n}{2}},\,B(\omega^{k}_{\frac{n}{2}}))

再对比一下,就是已知多项式 A(x)\text{DFT}(A(x)),已知多项式 B(x)\text{DFT}(B(x))

于是我们直接写出代码。

// 注意我们已经将多项式补为了 2 的幂
using ld = double; 
using cp = complex<ld>;
const ld pi = 3.141592653589;
#define rep(i, a, b) for (int i = a; i < b; i++)
void DFT(cp * a, int n) {
    if (n == 1) return;
    int mid = n >> 1; static cp b[N];
    rep(i, 0, mid) b[i] = a[i << 1], b[i + mid] = a[(i << 1) | 1];
    rep(i, 0, n) a[i] = b[i]; DFT(a, mid), DFT(a + mid, mid);
    cp w(1.0, 0.0), wn(cos(2.0 * pi / n), sin(2.0 * pi / n)); // wn 是 ω1n,初始 w 是 ω0n
    rep(i, 0, mid) {
        b[i] = a[i] + w * a[i + mid], b[i + mid] = a[i] - w * a[i + mid];
        w = w * wn;
    }
    rep(i, 0, n) a[i] = b[i];
}

时间复杂度 T(n) = 2T(\frac{n}{2}) + \Theta(n) 由主定理得 T(n) = \Theta(n \log n)

现在我们已经知道了全体单位根的点值,怎么求出来多项式?

我们将单位根设置为其倒数,进行 DFT,然后再将所有数除以 n 即可得到我们的多项式。

你如果没有任何线性代数基础,可以不关心证明。

我们求点值就是多项式当成向量,转置后左乘范德蒙矩阵,而范德蒙矩阵的逆矩阵也很特殊,我们对范德蒙矩阵的每个元素取倒数再除以 n 即可得到范德蒙矩阵的逆矩阵(证明可以自行搜索)。

而我们 DFT 的过程其实就是用 \Theta(n \log n) 的时间复杂度完成了这个单位根组成的矩阵和多项式向量的乘法,我们把单位根的倒数当成单位根,仍然不影响我们单位根的那三个性质,因此可以继续用 DFT 来求。

所以说重点是我们单位根的那三个性质,而不是我们单位根本身,比如 NTT 用的就是模意义下的原根。

因此我们只需要最后再除以 n 就可以了。

using ld = double; 
using cp = complex<ld>;
const ld pi = 3.141592653589;
#define rep(i, a, b) for (int i = a; i < b; i++)
void DFT(cp * a, int n, int f) {
    // 如果传入的参数 f 是 1,则表示是 DFT,否则是 IDFT
    if (n == 1) return;
    int mid = n >> 1; static cp b[N];
    rep(i, 0, mid) b[i] = a[i << 1], b[i + mid] = a[(i << 1) | 1];
    rep(i, 0, n) a[i] = b[i]; DFT(a, mid, f), DFT(a + mid, mid, f);
    cp w(1.0, 0.0), wn(cos(2.0 * pi / n), f * sin(2.0 * pi / n)); // wn 是 ω1n,初始 w 是 ω0n
    rep(i, 0, mid) {
        b[i] = a[i] + w * a[i + mid], b[i + mid] = a[i] - w * a[i + mid]; 
        w = w * wn;
    }
    rep(i, 0, n) a[i] = b[i];
}

然后得到我们的代码:

cp a[N], c[N]; int n;
void work() {
    DFT(a, n, 1), DFT(c, n, 1);
    rep(i, 0, n) a[i] *= c[i]; DFT(a, n, -1);
    rep(i, 0, n) ans[i] = int(a[i].real() / (ld)n + 0.5L);
    rep(i, 0, n) ans[i + 1] += ans[i] / 10, ans[i] = ans[i] % 10;
}

FFT 是可以写成非递归的,我们接下来要从底向上开始,用倍增代替掉递归。

我们模拟拆分过程,实际上对于 i,它最后到达的位置是 i 在二进制下翻转后的结果。比如对于 6 = 110,它最后到达的位置是 3 = 011

我们可以用递推来 \Theta(n) 计算每个数的二进制翻转。

// bit 是 n - 1 的位数
rep(i, 0, n) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); 

这个递推的思想相对简单,留给读者自行思考。

我们可以做到 \Theta(1) 的额外空间进行 DFT。

简单来说,我们知道了 A(\omega^{k}_{\frac{n}{2}})B(\omega^{k}_{\frac{n}{2}}) 对吧,我们二进制翻转后,我们需要保证:

然后我们就可以以从左到右的方式覆写数组的下标,而之前用过的下标之后显然是不会再用了,类似一个刷表法。 ```cpp #include <bits/stdc++.h> using namespace std; const int N = 3e6 + 10; #define rep(i, a, b) for(int i = a; i < b; i++) typedef double ld; using cp = complex<ld>; const ld pi = 3.141592653589; cp u, t; int ans[N], bit = 0, rev[N]; void DFT(cp* a, int n, int f) { if (n == 1) return; rep(i, 0, n) if (i < rev[i]) swap(a[i], a[rev[i]]); for (int mid = 1; mid < n; mid <<= 1) { cp wn(cos(pi / (ld)mid), f * sin(pi / (ld)mid)); // 我们的 "n" for (int i = 0; i < n; i += (mid << 1)) { cp w(1.0, 0.0); for (int j = 0; j < mid; j++, w = w * wn) { u = a[i + j], t = w * a[i + j + mid]; a[i + j] = u + t, a[i + j + mid] = u - t; } } } } cp a[N], c[N]; int n; void work() { DFT(a, n, 1), DFT(c, n, 1); rep(i, 0, n) a[i] *= c[i]; DFT(a, n, -1); rep(i, 0, n) ans[i] = int(a[i].real() / (ld)n + 0.1); rep(i, 0, n) ans[i + 1] += ans[i] / 10, ans[i] = ans[i] % 10; } int main() { string s1, s2; cin >> s1 >> s2; rep(i, 0, s1.size()) a[i] = cp(s1[s1.size() - i - 1] - '0', 0); rep(i, 0, s2.size()) c[i] = cp(s2[s2.size() - i - 1] - '0', 0); n = 1; int len = s1.size() + s2.size() - 1; while (n <= len) n <<= 1, ++bit; rep(i, 0, n) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); work(); len = s1.size() + s2.size() - 1; while (ans[len] == 0 && len) --len; for (int i = len; i >= 0; i--) cout << ans[i]; return 0; } ``` 代码是很久之前写的,可能比较远古。 - **关于一些扩展** FFT 的精度问题其实是比较大的,有些题让我们进行多项式乘法,并且系数对一个数(比如 $998244353$)取模。 或者没有取模,但是任意时刻我们的最大数都不会超过模数。 **快速数论变换 NTT**:模意义下的 FFT。 简单来说如果我们令 $n$ 为大于 $1$ 的 $2$ 的幂,并且 $p$ 为质数满足 $n \mid (p - 1)$,同时我们有 $g$ 是 $p$ 的一个原根。 - 原根你可以这样通俗的理解:一个数 $a$ 是模 $m$ 意义下的原根,当且仅当存在最小的一个 $k$ 使得 $a^k \equiv 1 \pmod m$,并且有 $k = \phi(m)$。 - 如果你学过群论,并且对于质数 $p$ 的剩余系,$a$ 的生成子群等于整个群,那么 $a$ 就是 $p$ 的原根。 这个有啥好处呢,我们设 $g_n = g^\frac{p - 1}{n}$,我们发现 $g^{k}_{n}$ 的性质包含了 $\omega^{k}_{n}$ 那三个性质。或者说,$\omega^{k}_{n}$ 在模 $p$ 意义下就是 $g^{k}_{n}$。 比如 $998244353 = 2^{23} \times 119 + 1$,那么这就是一个相当好用的模数。 然后我们的 $n$ 只要不超过 $2^{23}$,我们就可以用 NTT 来进行多项式乘法了(当然也因为这个,所以 NTT 的模数是相对受限的,比如我们一般不能用 $10^9 + 7$ 作为模数做 NTT)。 **更进一步的扩展**:如果我们要做多项式乘法,但是模数不是 NTT 模数,同时浮点数精度爆了怎么办? 1. 我们用三个模数跑 NTT!然后用 CRT 合并。当然这种方式要跑九遍 NTT,比较魔怔。 2. 把 FFT 拆成两段系数分别跑,据我所知一般是拆成五次 FFT 或者四次 FFT。似乎有拆成三次的做法,也挺魔怔的。 感兴趣者可以看看 P4245 这道题。