题解 AT3872 [AGC021F] Trinity

题解

分享一个 O(n + m^3) 且不使用 NTT 的做法。

首先直接写出朴素的 dp 式子,转移的详细意义可以参考其他题解:

dp_{m, n} = (1 + \binom{n + 1}{2})dp_{m - 1, n} + \sum_{i < n} \binom{n + 2}{i}dp_{m - 1, i}

最后答案为 \sum_{i = 0}^{n}\binom{n}{i}dp_{m, i}

自然考虑写成 EGF:记 F_m(x) = \sum_{n\geq 0}\frac{dp_{m, n}}{n!}x^n

那么转移式可以改写成如下的微分方程:

F_m = e^xF_{m - 1} + (2e^x - x - 2)F_{m - 1}' + (e^x - x - 1)F_{m - 1}''

之后可以考虑将 F_m 写成 \sum_{i = 0}^{m}\sum_{j = 0}^{m} f_{m, i, j}x^ie^{jx} 的形式(F_0 = 1 对应 i = j = 0,每次转移 x 的次数与 e^{jx} 中的 j 至多各增加 1),这样单次转移的代价就降到了 O(m^2)

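最终 O(m^3) 得到 F_m 后,只需要逐项考虑 x^ie^{jx} 的贡献:答案为 n![x^n]F_me^x = n!\sum_{i, j} f_{m, i, j}\frac{(j + 1)^{n - i}}{(n - i)!},共 O(m^2) 项,即可 O(m^2) 算出答案(下面的代码对每一项使用快速幂,因此这一步实际为 O(m^2 \log n))。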

#include <bits/stdc++.h>

// 64-bit integer used for intermediate products before reduction modulo P.
typedef long long ll;

// M bounds the (M + 5) x (M + 5) coefficient tables, N bounds the factorial
// tables, and P is the modulus used by every arithmetic helper in this file.
const int M = 200, N = 8000, P = 998244353;

// Returns x unchanged when x < P, otherwise x - P.
// For x in [0, 2P) the result is x mod P.
inline int norm(int x) {
    return x < P ? x : x - P;
}
// Returns x + P when x is negative, otherwise x unchanged.
// For x in (-P, P) the result is x mod P.
inline int reduce(int x) {
    return x < 0 ? x + P : x;
}
// In-place modular addition: x becomes x + y, minus P once if the sum reached P.
// Keeps x in [0, P) when both operands start in [0, P).
inline void add(int &x, int y) {
    x += y;
    if (x >= P) {
        x -= P;
    }
}
// In-place modular subtraction: x becomes x - y, plus P once if the difference is negative.
// Keeps x in [0, P) when both operands start in [0, P).
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0) {
        x += P;
    }
}
// Binary exponentiation: returns b^p modulo P (1 when p == 0).
// Each set bit of p multiplies the current power of b into the result;
// b is squared after every bit. Products are formed in 64 bits before reduction.
int mpow(int b, int p) {
    int result = 1;
    while (p) {
        if (p & 1) {
            result = (ll)result * b % P;
        }
        b = (ll)b * b % P;
        p >>= 1;
    }
    return result;
}

int fct[N + 5], ifct[N + 5];
int binom(int n, int m) {
    if (n < m || m < 0) return 0;
    else return (ll)fct[n] * ifct[m] % P * ifct[n - m] % P;
}
// Fills fct[0..N] with i! mod P and ifct[0..N] with the matching inverses.
// Only ifct[N] is obtained by exponentiation (fct[N]^(P-2)); the rest follow
// downward from ifct[i - 1] = ifct[i] * i.
void init() {
    fct[0] = 1;
    for (int i = 0; i < N; i++) {
        fct[i + 1] = (ll)fct[i] * (i + 1) % P;
    }

    ifct[N] = mpow(fct[N], P - 2);
    for (int i = N; i > 0; i--) {
        ifct[i - 1] = (ll)ifct[i] * i % P;
    }
}

int n, m;
int main() {
    init(), scanf("%d%d", &n, &m);
/*
    static int f[N + 5]; f[0] = 1;
    for (int i = 1; i <= m; i++) {
        static int g[N + 5];
        for (int i = 0; i <= n; i++)
            g[i] = f[i], f[i] = reduce(0 - (ll)f[i] * i % P);
        for (int i = 0; i <= n; i++) for (int j = i; j <= n; j++)
            f[j] = (f[j] + (ll)binom(j + 2, i) * g[i]) % P;
    }
    int ans = 0;
    for (int i = 0; i <= n; i++)
        ans = (ans + (ll)ifct[n - i] * f[i]) % P;
    printf("%lld\n", (ll)ans * fct[n] % P);
*/

// F_{m} = e^xF_{m - 1} + (2e^x - x - 2)F_{m - 1}' + (e^x - x - 1)F_{m - 1}''
// ans = n![x^n]F_{m}e^x

    static int f[M + 5][M + 5]; f[0][0] = 1; // \sum f_{i, j}x^{i}e^{jx}
    for (int k = 1; k <= m; k++) {
        static int g[M + 5][M + 5];
        for (int i = 0; i < k; i++) for (int j = 0; j < k; j++)
            g[i][j] = f[i][j], f[i][j] = 0;

        int del = 0;
        for (int i = 0; i < k; i++) for (int j = 0; j < k; j++) {
            add(f[i][j + 1], g[i][j]);

            if (i) {
                del = (ll)i * g[i][j] % P;
                add(f[i - 1][j + 1], norm(del << 1));
                sub(f[i - 1][j], norm(del << 1));
                sub(f[i][j], del);

                del = 2ll * i * j * g[i][j] % P;
                add(f[i - 1][j + 1], del);
                sub(f[i - 1][j], del);
                sub(f[i][j], del);

                if (i > 1) {
                    del = (ll)i * (i - 1) * g[i][j] % P;
                    add(f[i - 2][j + 1], del);
                    sub(f[i - 2][j], del);
                    sub(f[i - 1][j], del);
                }
            }

            del = (ll)j * g[i][j] % P;
            add(f[i][j + 1], norm(del << 1));
            sub(f[i][j], norm(del << 1));
            sub(f[i + 1][j], del);

            del = (ll)j * j * g[i][j] % P;
            add(f[i][j + 1], del);
            sub(f[i][j], del);
            sub(f[i + 1][j], del);
        }
    }

    int ans = 0;
    for (int i = 0; i <= m && i <= n; i++) for (int j = 0; j <= m; j++)
        ans = (ans + (ll)ifct[n - i] * mpow(j + 1, n - i) % P * f[i][j]) % P;
    printf("%lld\n", (ll)ans * fct[n] % P);
}