题解:P4948 数列求和

· · 题解

请注意,本文中的 n,k 是反的,且 x=a

之前我研究过这个问题,当时给出了一个 O(n^2 \log k)任意模数做法。本来计划把这题投到本校高一 OIer 入校测试 T2。昨天 hwh 突然把这题题号发给我,我改了一下通过了,发现题解区没有跟我一样的做法?

那么我们给出一个 O(n^2 \log k) 的任意模数分治做法。

S_{n}(k)=\sum_{i=1}^k i^n x^i

k 的奇偶讨论:

  1. k 为奇数时,有
\begin{aligned} S_{n}(k)&=\sum_{i=1}^k i^n x^i\\ &=x +\sum_{i=2}^{k} i^n x^i\\ &=x+x\sum_{i=1}^{k-1} (i+1)^n x^i \\ &=x+x\sum_{i=1}^{k-1}x^i \sum_{j=0}^n \binom{n}{j} i^j\\ &=x+x\sum_{j=0}^n \binom{n}{j} \sum_{i=1}^{k-1} i^j x^i \\ &= x+x\sum_{j=0}^n \binom{n}{j} S_{j}(k-1) \end{aligned}
  1. k 为偶数时,有
\begin{aligned} S_{n}(k)&=\sum_{i=1}^k i^n x^i\\ &=\sum_{i=1}^{k/2} i^n x^i+\sum_{i=k/2+1}^{k} i^n x^i \\ &=S_n(k/2)+x^{k/2} \sum_{i=1}^{k/2} (i+k/2)^n x^i\\ &=S_n(k/2)+x^{k/2} \sum_{i=1}^{k/2} x^i \sum_{j=0}^n \binom{n}{j} i^j (k/2)^{n-j}\\ &=S_n(k/2)+x^{k/2} \sum_{j=0}^n \binom{n}{j} (k/2)^{n-j}\sum_{i=1}^{k/2} i^j x^i \\ &=S_n(k/2)+x^{k/2} + \sum_{j=0}^n \binom{n}{j} (k/2)^{n-j} S_j(k/2) \end{aligned}

递归求解即可,边界为 k=1,此时 S_n(k)=x

需要注意,偶数情况的 (k/2)^{n-j} 不能直接快速幂,否则复杂度是 O(n^2 \log^2 k)。只需倒序枚举 j 即可做到 O(n^2 \log k),可以通过。

如果使用杨辉三角递推组合数,该做法支持任意模数。

constexpr int N = 2005;

long long n, k;
mint x, c[N][N];

void prework() {
    for (int i = 0; i <= n; i++) c[i][0] = c[i][i] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j < i; j++) c[i][j] = c[i - 1][j] + c[i - 1][j - 1];
    }
}

vector<mint> solve(long long k) {
    if (k == 1) return vector<mint>(n + 1, x);
    vector<mint> S(n + 1);
    if (k & 1) {
        auto T = solve(k - 1);
        for (int i = 0; i <= n; i++) {
            mint cur = 0;
            for (int j = 0; j <= i; j++) cur += c[i][j] * T[j];
            S[i] = x + x * cur;
        }
        return S;
    } else {
        auto T = solve(k >> 1);
        mint p = x.pow(k >> 1);
        for (int i = 0; i <= n; i++) {
            mint cur = 0, b = k >> 1, pw = 1;
            for (int j = i; j >= 0; j--, pw *= b) cur += c[i][j] * T[j] * pw;
            S[i] = T[i] + p * cur;
        }
        return S;
    }
}

void _main() {
    cin >> k >> x >> n;
    prework();
    auto S = solve(k);
    cout << S[n];
}