题解P6828 任意模数 Chirp-Z Transform

· · 题解

本文分两部分:Chirp-Z Transform 和 MTT。

Chirp-Z Transform

简介

Chirp-Z 变换(Chirp-Z Transform),又称 Bluestein 算法,是任意长度的卷积。

这个算法可以在 O(n\log n) 的时间复杂度内求解等比数列点值。

更形象化地:

给定三个整数 n,m,c,以及一个多项式 A(x)

这个算法可以在 \Theta((n+m)\log (n+m)) 的时间复杂度内求出

A(1),A(c),A(c^2),\cdots,A(c^{m-1})

当然,是在模意义下的。

实现

我们有几种实现方式。讲讲最主流的方式:

先想想组合数。根据组合数的定义 \binom{n}{m}=\frac{n!}{m!(n-m)!}(n\geq m\geq 0),我们可以得到

\binom{n}{2}=\dfrac{n!}{2!(n-2)!}=\dfrac{n(n-1)}{2}

这里即便除以零也没事,因为如同分式方程,没有必要在中途检验。

首先是这个式子

\begin{aligned} xy&=\dfrac{x^2+y^2+2xy-x^2-y^2}{2}\\ &=\dfrac{x^2+y^2+2xy-x-y-x^2-y^2+x+y}{2}\\ &=\dfrac{x^2+y^2+2xy-x-y}{2}-\dfrac{x^2-x}{2}-\dfrac{y^2-y}{2}\\ &=\binom{x+y}{2}-\binom{x}{2}-\binom{y}{2} \end{aligned}

然后我们将要求的东西代入函数里

\begin{aligned} A(c^{k})&=\sum_{i=0}^{\deg A}c^{ik}A\lbrack i\rbrack\\ &=\sum_{i=0}^{\deg A}c^{\binom{i+k}{2}-\binom{i}{2}-\binom{k}{2}}A\lbrack i\rbrack\\ &=c^{-\binom{k}{2}}\sum_{i=0}^{\deg A}c^{\binom{i+k}{2}-\binom{i}{2}}A\lbrack i\rbrack \end{aligned}

这已经明摆着的是一个卷积的形式了。

我们只需要预处理出 c 的幂,然后做卷积,最后乘上 c^{-\binom{k}{2}} 就可以了。

但是显然这甚至没有直接计算快。我们试着加速。

\begin{aligned} P(x)&=\sum_{i=0}^{\deg A}A\lbrack \deg A-i\rbrack c^{-\binom{\deg A-i}{2}}x^i\\ Q(x)&=\sum_{i=0}^{\deg A}c^{\binom{i}{2}}x^i \end{aligned}

于是我们有

\begin{aligned} P(x)\lbrack i\rbrack&=A\lbrack \deg A-i\rbrack c^{-\binom{\deg A-i}{2}}\\ Q(x)\lbrack i\rbrack&=c^{\binom{i}{2}} \end{aligned}

再推一波式子。

\begin{aligned} (P(x)Q(x))\lbrack\deg A+k\rbrack&=\sum_{i=0}^{\deg A+k}(P(x)\lbrack \deg A+k-i\rbrack\times Q(x)\lbrack i\rbrack)\\ &=\sum_{i=0}^{\deg A+k}(A\lbrack{\color{red}i-k}\rbrack\times c^{-\binom{\color{red}i-k}{2}}\times c^\binom{i}{2}) \end{aligned}

注意到此时 i-k 可能小于零。于是我们设 \forall i\not\in\lbrack0,\deg A\rbrack,A\lbrack i\rbrack=0

在这一步中计算组合数发生除以零是没事的,因为如同分式方程,没有必要在中途检验。

\begin{aligned} (P(x)Q(x))\lbrack\deg A+k\rbrack&=\sum_{i=0}^{\deg A+k}(A\lbrack i-k\rbrack\times c^{-\binom{i-k}{2}}\times c^\binom{i}{2})\\ &=\sum_{i=-k}^{\deg A}(A\lbrack i\rbrack\times c^{\binom{i+k}{2}-\binom{i}{2}})\\ &=\sum_{i=0}^{\deg A}(A\lbrack i\rbrack\times c^{\binom{i+k}{2}-\binom{i}{2}})\\ &=c^{\binom{k}{2}}\times A(c^k) \end{aligned}

于是我们对 P(x)Q(x) 做卷积,可以得到 c^{\binom{0}{2}}\times A(c^0),c^{\binom{1}{2}}\times A(c^1),\cdots,c^{\binom{m-1}{2}}\times A(c^{m-1})

注意到 \binom{0}{2}\binom{1}{2} 是不合法的,所以我们只能套公式。

由于得到的结果是 m 项的,所以我们要做的卷积长度为 \deg A+m(精确来说应该是 2^{\lceil\log(\deg A+m)\rceil})。

代码

inline void ChirpZ(size_t n, mint c, size_t m, vector<mint> a, vector<mint>& b) {
    static vector<mint> pwc, ipwc, f, g;
    pwc.resize(n + m + 1);
    ipwc.resize(n + m + 1);
    size_t i = 1, iend = std::max(n, m);
    pwc[1] = c;
    ipwc[1] = c.inv();
    for (pwc[0] = i = 1; i <= n + m; i++) pwc[i] = pwc[i - 1] * pwc[1];
    for (pwc[0] = i = 1; i <= n + m; i++) pwc[i] = pwc[i] * pwc[i - 1];
    for (ipwc[0] = i = 1; i <= iend; i++) ipwc[i] = ipwc[i - 1] * ipwc[1];
    for (ipwc[0] = i = 1; i <= iend; i++) ipwc[i] = ipwc[i] * ipwc[i - 1];
    for (i = 1; i < n; i++) a[i] *= ipwc[i - 1];
    size_t len = 1;
    int x = -1;
    while (len < n + m) len <<= 1, x++;
    get_rev(len, x);
    f.resize(len);
    g.resize(len);
    f[0] = 1;
    std::copy_n(pwc.begin(), n + m, f.begin() + 1);
    for (i = 1; i <= n; i++) g[i] = a[n - i];
    mul(n + m, n + m, f, g, f, len);
    b.resize(m);
    b[0] = f[n];
    for (i = 1; i < m; i++) b[i] = f[i + n] * ipwc[i - 1];
}

MTT

预处理单位根

还是一个比较简单的提高精度的科技。

通过预处理单位根算出 \omega_{len}^i,而不是每次都乘,这样会掉精度。

同时能减小常数,避免重复的运算。

inline void get_rev(size_t len, int x) {
    if (len == rev.size()) return;
    rev.resize(len);
    for (size_t i = 0; i < len; i++) rev[i] = rev[i >> 1ull] >> 1ull | (i & 1ull) << x;
    omegas.resize(len);
    for (size_t i = 0; i < len; i++) omegas[i] = cp(std::cos(two * pi / len * i), std::sin(two * pi / len * i));
}
inline void FFT(vector<cp>& a, size_t n, bool type) {
    for (size_t i = 1ull; i < n; ++i) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (size_t i = 2ull; i <= n; i <<= 1ull) for (size_t j = 0ull, l = (i >> 1ull), ch = n / i; j < n; j += i) for (size_t k = j, now = 0ull; k < j + l; k++) {
        cp x = a[k], y = a[k + l] * (type ? omegas[now] : omegas[now].conj());
        a[k] = x + y;
        a[k + l] = x - y;
        now += ch;
    }
    if (!type) {
        for (size_t i = 0; i < n; i++) {
            a[i].real() /= n;
            a[i].imag() /= n;
        }
    }
}

第一轮合并

我们发现,在 7 次 FFT 的拆系数 FFT 中,复数中我们只用了实部,并没有使用虚部。我们尝试使用。

我们设

\begin{aligned} P_0(x)&=A_0(x)+\mathrm{i}A_1(x)\\ P_1(x)&=A_0(x)-\mathrm{i}A_1(x)\\ Q(x)&=B_0(x)+\mathrm{i}B_1(x) \end{aligned}

然后我们将他们乘起来,得到

\begin{aligned} P_0(x)Q(x)&=(A_0(x)+\mathrm{i}A_1(x))(B_0(x)+\mathrm{i}B_1(x))\\ &=A_0(x)B_0(x)-A_1(x)B_1(x)+\mathrm{i}(A_0(x)B_1(x)+A_1(x)B_0(x))\\ P_1(x)Q(x)&=(A_0(x)-\mathrm{i}A_1(x))(B_0(x)+\mathrm{i}B_1(x))\\ &=A_0(x)B_0(x)+A_1(x)B_1(x)+\mathrm{i}(A_0(x)B_1(x)-A_1(x)B_0(x))\\ \end{aligned}

于是

\begin{aligned} P_0(x)Q(x)+P_1(x)Q(x)&=2(A_0(x)B_0(x)+\mathrm{i}A_0(x)B_1(x))\\ P_1(x)Q(x)-P_0(x)Q(x)&=2(A_1(x)B_1(x)-\mathrm{i}A_1(x)B_0(x)) \end{aligned}

我们可以在上述式子中找出所有我们所需要的项。

总计 5 次 FFT。

代码:

inline void mul(size_t n, size_t m, const vector<mint>& a, const vector<mint>& b, vector<mint>& c, size_t len) {
    const size_t up = sqrt(mint::mod());
    mint v1, v2, v3;
    static size_t i;
    vector<cp> P(len, cp(0, 0)), Q(len, cp(0, 0)), R(len, cp(0, 0));
    for (i = 0ull; i < n; i++) {
        P[i] = cp(a[i].val() % up, a[i].val() / up);
        Q[i] = P[i].conj();
    }
    for (i = 0ull; i < m; i++) R[i] = cp(b[i].val() % up, b[i].val() / up);
    FFT(P, len, true);
    FFT(Q, len, true);
    FFT(R, len, true);
    for (i = 0ull; i < len; i++) {
        P[i] = P[i] * R[i];
        Q[i] = Q[i] * R[i];
    }
    FFT(P, len, false);
    FFT(Q, len, false);
    c.resize(n + m - 1);
    for (i = 0ull; i < n + m - 1; i++) {
        v1 = llround((P[i].real() + Q[i].real()) / two);
        v2 = llround((Q[i].real() - P[i].real()) / two);
        v3 = llround(P[i].imag());
        c[i] = v1 + v2 * up * up + v3 * up;
    }
}

实测不能通过。

第二轮合并

本质上就是对 P_0(x)P_1(x) 的合并。

第一种合并方式

这是 OIer 最常写的拆系数 FFT 版本。

通过上边提到的特性计算出了 \mathcal{F}(F_0(x))\lbrack i\rbrack\mathcal{F}(F_1(x))\lbrack i\rbrack。但是这种方法不同就在于,它直接求出了 A_0(x),A_1(x)。然后将 A_0(x)B_0(x)+\mathrm{i}A_0(x)B_1(x)A_1(x)B_0(x)+\mathrm{i}A_1(x)B_1(x) 算出。

然后按照定义合并。

第二种合并方式

设 FFT 长度为 n。根据 FFT 的定义 \mathcal{F}(F(x))\lbrack i\rbrack=F(\omega_{n}^i)=\sum_{j=0}^{n-1}\omega_{n}^{ij}F\lbrack j\rbrack,我们能得到

\begin{aligned} \mathcal{F}(P_0(x))\lbrack i\rbrack&=P_0(\omega_{n}^i)\\ &=A_0(\omega_{n}^i)+\mathrm{i}A_1(\omega_{n}^i)\\ &=\sum_{j=0}^{n-1}\omega_{n}^{ij}A_0\lbrack j\rbrack +\mathrm{i} \sum_{j=0}^{n-1}\omega_{n}^{ij}A_1\lbrack j\rbrack\\ &=\sum_{j=0}^{n-1}\omega_{n}^{ij}(A_0\lbrack j\rbrack +\mathrm{i} A_1\lbrack j\rbrack)\\ \mathcal{F}(P_1(x))\lbrack n-i\rbrack&=P_1(\omega_{n}^{-i})\\ &=A_0(\omega_{n}^{-i})-\mathrm{i}A_1(\omega_{n}^{-i})\\ &=\sum_{j=0}^{n-1}\omega_{n}^{-ij}A_0\lbrack j\rbrack -\mathrm{i} \sum_{j=0}^{n-1}\omega_{n}^{-ij}A_1\lbrack j\rbrack\\ &=\sum_{j=0}^{n-1}\omega_{n}^{-ij}(A_0\lbrack j\rbrack-\mathrm{i}A_1\lbrack j\rbrack)\\ \end{aligned}

考虑单位根的几何意义 \omega_{n}^i=\cos\dfrac{2\pi i}{n}+\mathrm{i}\sin\dfrac{2\pi i}{n}

代入上式,得到

\begin{aligned} \mathcal{F}(P_0(x))\lbrack i\rbrack &=\sum_{j=0}^{n-1}\omega_{n}^{ij}(A_0\lbrack j\rbrack +\mathrm{i} A_1\lbrack j\rbrack)\\ &=\sum_{j=0}^{n-1}(\cos\dfrac{2\pi ij}{n}+\mathrm{i}\sin\dfrac{2\pi ij}{n})(A_0\lbrack j\rbrack +\mathrm{i} A_1\lbrack j\rbrack)\\ &=\sum_{j=0}^{n-1}(\cos\dfrac{2\pi ij}{n}\times A_0\lbrack j\rbrack-\sin\dfrac{2\pi ij}{n}\times A_1\lbrack j\rbrack+\mathrm{i}(\sin\dfrac{2\pi ij}{n}\times A_0\lbrack j\rbrack+\cos\dfrac{2\pi ij}{n}\times A_1\lbrack j\rbrack))\\ \mathcal{F}(P_1(x))\lbrack n-i\rbrack&=\sum_{j=0}^{n-1}\omega_{n}^{-ij}(A_0\lbrack j\rbrack-\mathrm{i}A_1\lbrack j\rbrack)\\ &=\sum_{j=0}^{n-1}(\cos\dfrac{-2\pi ij}{n}+\mathrm{i}\sin\dfrac{-2\pi ij}{n})(A_0\lbrack j\rbrack-\mathrm{i}A_1\lbrack j\rbrack)\\ &=\sum_{j=0}^{n-1}(\cos\dfrac{2\pi ij}{n}-\mathrm{i}\sin\dfrac{2\pi ij}{n})(A_0\lbrack j\rbrack-\mathrm{i}A_1\lbrack j\rbrack)\\ &=\sum_{j=0}^{n-1}(\cos\dfrac{2\pi ij}{n}\times A_0\lbrack j\rbrack-\sin\dfrac{2\pi ij}{n}\times A_1\lbrack j\rbrack-\mathrm{i}(\sin\dfrac{2\pi ij}{n}\times A_0\lbrack j\rbrack+\cos\dfrac{2\pi ij}{n}\times A_1\lbrack j\rbrack)\\ \end{aligned}

然后我们发现 \mathcal{F}(P_0(x))\lbrack i\rbrack\mathcal{F}(P_1(x))\lbrack n-i\rbrack 是共轭的。(这个我在 这篇文章 里提到过)

于是我们对 P_0 做 FFT 之后,可以根据定义得到 P_1 FFT 后的结果。

总计 4 次 FFT。

代码:

第一种 4 次 FFT 的拆系数 FFT
inline void mul(int n, int m, const int& p, vector<int> a, vector<int> b, vector<int>& c) {
    static cp da, db, dc, dd, rel(0.5, 0), img(0, -0.5);
    static const int up = 1 << 15;
    static int v1, v2, v3;
    int len = 1, x = -1;
    while (len < n + m) len <<= 1, x++;
    a.resize(len);
    b.resize(len);
    get_rev(len, x);
    for (int i = 0; i < len; i++) {
        a[i] = i < n ? (a[i] % p + p) % p : 0;
        b[i] = i < m ? (b[i] % p + p) % p : 0;
    }
    vector<cp> P(len, cp(0, 0)), Q(len, cp(0, 0)), R(len, cp(0, 0));
    for (int i = 0; i < len; i++) {
        P[i] = cp(a[i] >> 15, a[i] & 32767);
        Q[i] = cp(b[i] >> 15, b[i] & 32767);
    }
    FFT(P, len, 1);
    FFT(Q, len, 1);
    for (int i = 0; i < len; i++) {
        R[i] = P[(len - i) & (len - 1)].conj();
    }
    for (int i = 0; i < len; i++) {
        auto x = P[i];
        auto y = R[i];
        P[i] = (x + y) * rel;
        R[i] = (x - y) * img;
    }
    for (int i = 0; i < len; i++) {
        P[i] = P[i] * Q[i];
        Q[i] = R[i] * Q[i];
    }
    FFT(P, len, -1);
    FFT(Q, len, -1);
    c.clear();
    c.resize(len);
    for (int i = 0; i < len; i++) {
        v1 = (int)(P[i].real() + 0.5);
        v2 = (int)(P[i].imag() + 0.5) + (int)(Q[i].real() + 0.5);
        v3 = (int)(Q[i].imag() + 0.5);
        c[i] = 1ll * v1 * up % p * up % p + 1ll * v2 * up % p + v3 % p;
        c[i] = (c[i] % p + p) % p;
    }
}
第二种 4 次 FFT 的拆系数 FFT
inline void mul(size_t n, size_t m, const vector<mint>& a, const vector<mint>& b, vector<mint>& c, size_t len) {
    const size_t up = sqrt(mint::mod());
    mint v1, v2, v3;
    vector<cp> P(len, cp(0, 0)), Q(len, cp(0, 0)), R(len, cp(0, 0));
    for (size_t i = 0ull; i < n; i++) P[i] = cp(a[i].val() % up, a[i].val() / up);
    for (size_t i = 0ull; i < m; i++) R[i] = cp(b[i].val() % up, b[i].val() / up);
    FFT(P, len, true);
    FFT(R, len, true);
    Q[0] = P[0].conj();
    for (size_t i = 1ull; i < len; i++) Q[i] = P[len - i].conj();
    for (size_t i = 0ull; i < len; i++) {
        P[i] = P[i] * R[i];
        Q[i] = Q[i] * R[i];
    }
    FFT(P, len, false);
    FFT(Q, len, false);
    c.resize(n + m - 1);
    for (size_t i = 0ull; i < n + m - 1; i++) {
        v1 = llround((P[i].real() + Q[i].real()) / two);
        v2 = llround((Q[i].real() - P[i].real()) / two);
        v3 = llround(P[i].imag());
        c[i] = v1 + v2 * up * up + v3 * up;
    }
}