Solution「『P8158 「PMOI-5」分力』题解」

· · 题解

原题

  • 求下列两式: \dfrac1n\sum_{i=0}^{n-1}\left(\cos\dfrac{2\pi i}{n}-\dfrac{\sum\limits_{j=0}^{n-1}\cos\dfrac{2\pi j}{n}}{n}\right)^k \dfrac1n\sum_{i=0}^{n-1}\left(\sin\dfrac{2\pi i}{n}-\dfrac{\sum\limits_{j=0}^{n-1}\sin\dfrac{2\pi j}{n}}{n}\right)^k
  • 答案对 998\ 244\ 353 取模。

难得的 FFT 应用。题解参考鱼神。

拿到题好像证明这是个有理数都很难证!

Part #1 初步化简式子

我们发现:

\sum_{i=0}^{n-1}\cos\dfrac{2\pi i}{n}=0

\sum_{i=0}^{n-1}\sin\dfrac{2\pi i}{n}=0

这是容易证明的。如果你不会,可以考虑 n 次单位根 \omega_n=\cos\dfrac{2\pi}{n}+\text{i}\sin\dfrac{2\pi}{n}。(根据个人习惯,正体的 \text i 表示虚数单位,而斜体的 i 表示下标)。

因为

1-\omega_n^n=0

所以

\left(1-\omega_n\right)\left(1+\omega_n+\omega_n^2+\cdots+\omega_n^{n-1}\right)=0

因为 1-\omega_n\neq0,所以 1+\omega_n+\omega_n^2+\cdots+\omega_n^{n-1}=0,分别提取实部与虚部即得上面两式。

于是原题即求:

\dfrac1n\sum_{i=0}^{n-1}\cos^k\dfrac{2\pi i}{n} \dfrac1n\sum_{i=0}^{n-1}\sin^k\dfrac{2\pi i}{n}

Part #2 转向 FFT

发现 k 次幂非常难用单位根直接表示。

注意

\cos\theta=\dfrac{\text e^{\text i\theta}+\text e^{-\text i\theta}}{2} \sin\theta=\dfrac{\text e^{\text i\theta}-\text e^{-\text i\theta}}{2\text i}

先求 \dfrac1n\sum\limits_{i=0}^{n-1}\cos^k\dfrac{2\pi i}{n}

\begin{aligned} \dfrac1n\sum_{i=0}^{n-1}\cos^k\dfrac{2\pi i}{n} & = \dfrac1n\sum_{i=0}^{n-1}\left(\dfrac{\text e^{\text i\frac{2\pi i}{n}}+\text e^{-\text i\frac{2\pi i}{n}}}{2}\right)^k \\ & = \dfrac{1}{2^k}\cdot\dfrac1n\sum_{i=0}^{n-1}\left(\omega_n^i+\omega_n^{-i}\right)^k \end{aligned}

注意到:一个序列的平均数就等于这个区间做 IDFT 得到多项式的常数项。如果不知道原因,则需要对 FFT 进行深入了解,可看 OI Wiki 中快速傅里叶逆变换部分。

为什么我们不考虑 DFT 呢?因为 \omega_n 提示我们将它代换为 x,成为关于 x 的多项式。所以我们得到的多项式是 x+x^{-1}。上式在 FFT 的语言下就是,\left(x+x^{-1}\right)^k 的常数项。

好像有点不对?

我们做的是长为 n 的 DFT 与长为 n 的 IDFT,所以实际上我们做的是模 x^n-1 下的循环卷积。

所以上面的式子需要在模 x^n-1 下进行。因为是循环卷积,只能使用快速幂计算。注意乘上 \dfrac{1}{2^k}。时间复杂度 \mathcal O(n\log n\log k)

再求 \dfrac1n\sum\limits_{i=0}^{n-1}\sin^k\dfrac{2\pi i}{n}

\begin{aligned} \dfrac1n\sum_{i=0}^{n-1}\sin^k\dfrac{2\pi i}{n} & = \dfrac1n\sum_{i=0}^{n-1}\left(\dfrac{\text e^{\text i\frac{2\pi i}{n}}-\text e^{-\text i\frac{2\pi i}{n}}}{2\text i}\right)^k \\ & = \dfrac{1}{2^k\cdot\text i^k}\cdot\dfrac1n\sum_{i=0}^{n-1}\left(\omega_n^i-\omega_n^{-i}\right)^k \end{aligned}

于是即求模 x^n-1\left(x-x^{-1}\right)^k 的常数项。注意乘上 \dfrac{1}{2^k\cdot\text i^k}

我们顺便证明了答案一定是有理数。

实现时,注意:

以下是代码。

#include <bits/stdc++.h>
using namespace std;
const int N = 66010, LOGN = 20;
const int mod = 998244353, g = 114514, invg = 137043501;
int n, k, rgt[N];
vector <int> w1[LOGN], w2[LOGN];
inline int power (int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) {
            ans = (1ll * ans * a) % mod;
        }
        a = (1ll * a * a) % mod;
        b >>= 1;
    }
    return ans;
}
inline int modplus (int a, int b) {
    int x = a + b;
    return (x >= mod? x - mod: x);
}
inline int modminu (int a, int b) {
    int x = a - b;
    return (x < 0? x + mod: x);
}
void init (int n) {
    int k = 0;
    for (; (1 << k) <= n; k++);
    for (int i = 0; i < (1 << k); i++) {
        rgt[i] = (rgt[i >> 1] >> 1) | ((i & 1) << (k - 1));
    }
    for (int i = 1; i <= k; i++) {
        w1[i].resize (1 << i);
        w1[i][0] = 1, w1[i][1] = power (g, (mod - 1) >> i);
        for (int j = 2; j < (1 << i); j++) {
            w1[i][j] = (1ll * w1[i][j - 1] * w1[i][1]) % mod;
        }
        w2[i].resize (1 << i);
        w2[i][0] = 1, w2[i][1] = power (invg, (mod - 1) >> i);
        for (int j = 2; j < (1 << i); j++) {
            w2[i][j] = (1ll * w2[i][j - 1] * w2[i][1]) % mod;
        }
    }
}
void dft (vector <int>& a, int inv) {
    int n = a.size ();
    assert ((n & (-n)) == n);
    for (int i = 0; i < n; i++) {
        if (i < rgt[i]) {
            swap (a[i], a[rgt[i]]);
        }
    }
    for (int mid = 1, st = 1; mid < n; mid <<= 1, st++) {
        for (int i = 0; i < n; i += mid << 1) {
            int* w = &(inv == 1? w1: w2)[st][0];
            for (int j = 0; j < mid; j++) {
                int x = a[i + j], y = 1ll * *(w++) * a[i + j + mid] % mod;
                a[i + j] = modplus (x, y);
                a[i + j + mid] = modminu (x, y);
            }
        }
    }
    if (inv == -1) {
        int ninv = power (n, mod - 2);
        for (int i = 0; i < n; i++) {
            a[i] = (1ll * a[i] * ninv) % mod;
        }
    }
}
struct poly {
    vector <int> a;
    poly () {}
    poly (int n_) {
        a.resize (n_ + 1);
    }
    poly& operator = (int n_) {
        a.resize (n_ + 1);
        return *this;
    }
    int size () {
        return a.size () - 1;
    }
    int& operator [] (int p) {
        return a[p];
    }
    void suit () {
        int n = a.size () - 1;
        int k = 0;
        for (; (1 << k) <= n; k++);
        a.resize (1 << k);
    }
};
poly pcopy (poly pre, int n_) {
    poly ans = pre;
    return ans = n_;
}
poly operator * (poly f, poly g) {
    poly flong = pcopy (f, n * 2 - 2);
    poly glong = pcopy (g, n * 2 - 2);
    flong.suit ();
    glong.suit ();
    dft (flong.a, 1);
    dft (glong.a, 1);
    for (int i = 0; i <= flong.size (); i++) {
        flong[i] = (1ll * flong[i] * glong[i]) % mod;
    }
    dft (flong.a, -1);
    flong = n * 2 - 2;
    for (int i = n; i <= n * 2 - 2; i++) {
        flong[i % n] = modplus (flong[i % n], flong[i]); 
        flong[i] = 0;
    }
    return flong = n;
}
poly f, ans;
int ansx, ansy;
int main () {
    scanf ("%d%d", &n, &k);
    init (n * 2 - 2);
    f = n - 1, ans = n - 1;
    f[1] += 1, f[n - 1] += 1, ans[0] = 1;
    int t = k;
    while (t != 0) {
        if (t & 1) {
            ans = ans * f;
        }
        f = f * f;
        t >>= 1;
    }
    int ex = mod - 1 - k < 0? mod - 1 - k + mod - 1: mod - 1 - k;
    ansx = 1ll * power (2, ex) * ans[0] % mod;
    if (n != 2 && !(k & 1)) {
        for (int i = 0; i < n; i++) {
            f[i] = ans[i] = 0;
        }
        f[1] += 1, f[n - 1] += mod - 1, ans[0] = 1;
        int t = k;
        while (t != 0) {
            if (t & 1) {
                ans = ans * f;
            }
            f = f * f;
            t >>= 1;
        }
        ansy = 1ll * power (2, ex) * ans[0] % mod;
        if ((k >> 1) & 1) {
            ansy = mod - ansy;
        } 
    }
    printf ("%d %d\n", ansx, ansy);
    return 0;
}