题解 P5915 【冬至】

· · 题解

第一部分:动态规划

首先考虑一个简单的 \text{DP}

然而这个做法是 O(nk) 的,远远高于我们的预期,我们需要想办法优化它。

第二部分:矩阵优化

发现题面中的 n 特别大,我们要尽力在这一维上优化。

容易发现这个转移关于第一维来说是常系数、线性的,于是我们可以用矩阵来描述这个转移。具体地:

\left[ \begin{matrix} 1 & 1 & 1 & \cdots & 1 \\ k-1& 1 & 1 & \cdots & 1 \\ 0 &k-2& 1 & \cdots & 1 \\ \vdots&\vdots &\vdots&\ddots& \vdots\\ 0 & 0 & 0 & \cdots & 1 \end {matrix} \right] \times \left[ \begin {matrix} f(i,1)\\ f(i,2)\\ \vdots\\ f(i,k-1) \end {matrix} \right] = \left[ \begin {matrix} f(i + 1, 1)\\ f(i + 1, 2)\\ \vdots\\ f(i + 1, k - 1) \end {matrix} \right]

其中矩阵、向量的维数都是 k-1

我们可以写成这个形式:

A\times F_i=F_{i+1}

因此 A^{n-1}\times F_1=F_n

使用矩阵快速幂就可以得到 O(k^3\log n) 的做法,仍然无法通过。

使用特征多项式进一步优化

现在看起来没有什么优化的途径了?

发现我们在线性递推问题中也遇到了这个瓶颈,当时我们的解决办法是使用矩阵的特征多项式

这里复述一下这个科技:

在这个问题中同样可以使用这个方法。

求特殊矩阵的特征多项式

接着问题来了:如何求这个矩阵 A 的特征多项式呢?

注:接下来的 k 都指题目中的 k-1

F_k(\lambda)=\det(\lambda I-A)=\det\left[ \begin {matrix} \lambda-1&-1&\cdots&-1\\ -k&\lambda-1&\cdots&-1\\ 0&1-k&\cdots&-1\\ \vdots&\vdots&\ddots&\vdots\\ 0&0&\cdots&\lambda-1 \end{matrix} \right]

将这个行列式按第一列展开

我们得到了这种矩阵的特征多项式的递推式,打个表就能得到精确的结果了。

事实上:

F_k(\lambda)=\lambda^k-k\lambda^{k-1}-(k-1)\lambda^{k-2}-2(k-2)\lambda^{k-3}-\dots-(k-1)! =\lambda^k-\sum_{i=1}^k(k-i+1)\cdot(i-1)!\cdot \lambda^{k-i}

得到特征多项式后,沿用线性递推的做法即可。需要注意初始状态。

x^n\bmod f(x) 时使用快速幂,且需要多项式取模。

参考代码:

这个多项式取模是我很早以前的版本,可能码风不太一样。

#include <iostream>
#include <algorithm>
#include <cstring>

const int N = 131072;
const int mod = 998244353;
typedef long long LL;
typedef unsigned long long ULL;

int rev[N], wn[N], lim, s, w[N];

int pow(int x, int y) {
    int ans = 1;
    for (; y; y >>= 1, x = static_cast<LL> (x) * x % mod)
        if (y & 1) ans = static_cast<LL> (ans) * x % mod;
    return ans;
}
void reduce(int &x) {
    x += x >> 31 & mod;
}
void fftinit(int len) {
    wn[0] = lim = 1, s = -1; while (lim < len) lim <<= 1, ++s;
    for (int i = 0; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    const int g = pow(3, (mod - 1) / lim);
    for (int i = 1; i < lim; ++i) wn[i] = (LL) wn[i - 1] * g % mod;
}
void fft(int *A, int typ) {
    static ULL tmp[N];
    for (int i = 0; i < lim; ++i) tmp[rev[i]] = A[i];
    for (int i = 1; i < lim; i <<= 1) {
        for (int j = 0, t = lim / i / 2; j < i; ++j) w[j] = wn[j * t];
        for (int j = 0; j < lim; j += i << 1)
            for (int k = 0; k < i; ++k) {
                const ULL x = tmp[k + j + i] * w[k] % mod;
                tmp[k + j + i] = tmp[k + j] + mod - x, tmp[k + j] += x;;
            }
    }
    for (int i = 0; i < lim; ++i) A[i] = tmp[i] % mod;
    if (!typ) {
        const int il = pow(lim, mod - 2); std::reverse(A + 1, A + lim);
        for (int i = 0; i < lim; ++i) A[i] = (LL) A[i] * il % mod;
    }
}
void inv(int *A, int *B, int n) {
    if (n == 1) { B[0] = pow(A[0], mod - 2); return; }
    int n_ = n + 1 >> 1; inv(A, B, n_), fftinit(n_ * 3);
    static int C[N], D[N];
    std::memcpy(D, B, n_ << 2), std::memset(D + n_, 0, lim - n_ << 2);
    std::memcpy(C, A, n << 2), std::memset(C + n, 0, lim - n << 2);
    fft(C, 1), fft(D, 1);
    for (int i = 0; i < lim; ++i) C[i] = mod - (LL) C[i] * D[i] % mod * D[i] % mod;
    fft(C, 0), std::memcpy(B + n_, C + n_, n - n_ << 2);
}

int g[N], rg[N], irg[N], k, n, ans, factor[N];

int rA[N];
void Divmod(int *A) {
    int n = 2 * k;
    std::reverse_copy(A, A + n, rA);
    std::memset(rA + k, 0, lim - k << 2);
    fft(rA, 1);
    for (int i = 0; i < lim; ++i)
        rA[i] = static_cast<LL> (rA[i]) * irg[i] % mod;
    fft(rA, 0);
    std::memset(rA + k, 0, lim - k << 2);
    std::reverse(rA, rA + k);
    fft(rA, 1);
    for (int i = 0; i < lim; ++i)
        rA[i] = static_cast<LL> (rA[i]) * g[i] % mod;
    fft(rA, 0);
    for (int i = 0; i < n; ++i)
        reduce(A[i] -= rA[i]);
}
int f[N];
void Powmod(int n) {
    if (n < k) {
        f[n] = 1;
        return;
    }
    Powmod(n >> 1);
    fft(f, 1);
    for (int i = 0; i < lim; ++i)
        f[i] = static_cast<LL> (f[i]) * f[i] % mod;
    fft(f, 0);
    if (n & 1) {
        for (int i = 2 * k - 1; i; --i)
            f[i] = f[i - 1];
        f[0] = 0;
    }
    Divmod(f);
}

int main() {
    std::cin >> n >> k, --k;
    g[k] = 1;
    factor[0] = 1;
    for (int i = 1; i < k; ++i)
        factor[i] = static_cast<LL> (factor[i - 1]) * i % mod;
    for (int i = 0; i < k; ++i)
        g[i] = mod - static_cast<LL> (i + 1) * factor[k - 1 - i] % mod;
    std::reverse_copy(g, g + k + 1, rg);
    inv(rg, irg, k);
    fftinit(k << 1 | 1);
    fft(irg, 1), fft(g, 1);
    Powmod(n - 1);
    for (int i = 0; i < k; ++i)
        reduce(ans += static_cast<LL> (pow(k + 1, i + 1)) * f[i] % mod - mod);
    std::cout << ans << '\n';
    return 0;
}