题解 AT3872 [AGC021F] Trinity
Tiw_Air_OAO · · 题解
分享一个 $O(n + m^3 + m^2\log n)$ 的做法。
首先直接写出朴素的 dp 式子,转移的详细意义可以参考其他题解:$$f_k(j)=\sum_{i=0}^{j}\binom{j+2}{i}f_{k-1}(i)-j\,f_{k-1}(j),\qquad f_0(j)=[j=0].$$
最后答案为 $\sum_{i=0}^{n}\binom{n}{i}f_m(i)$。
自然考虑写成 EGF:记 $F_k(x)=\sum_{j\ge 0}f_k(j)\dfrac{x^j}{j!}$。
那么转移式可以改写成如下的微分方程(其中 $F_k$ 为第 $k$ 轮 dp 数组的 EGF):$$F_k=e^xF_{k-1}+(2e^x-x-2)F_{k-1}'+(e^x-x-1)F_{k-1}''.$$
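之后可以考虑将 $F_k$ 表示成 $\sum_{i,j}f_{i,j}\,x^ie^{jx}$ 的形式,直接对系数 $f_{i,j}$ 递推;每做一轮,$i$ 与 $j$ 的上界各增加 $1$,因此单轮转移是 $O(m^2)$ 的。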
最终答案为 $n!\,[x^n]F_m(x)e^x=n!\sum_{i,j}f_{i,j}\dfrac{(j+1)^{n-i}}{(n-i)!}$,总复杂度 $O(n+m^3+m^2\log n)$。
#include <bits/stdc++.h>
// 64-bit integer used for intermediate products before they are reduced modulo P.
typedef long long ll;
// Dimension bound of the coefficient tables f and g (declared [M + 5][M + 5]) in main.
const int M = 200;
// Largest index filled in the factorial tables fct / ifct by init().
const int N = 8000;
// Modulus used by every arithmetic helper in this file.
const int P = 998244353;
// Returns x - P when x >= P and x unchanged otherwise.
// Only one subtraction is performed, so the result lands in [0, P)
// exactly when the argument is in [0, 2P).
inline int norm(int x) {
    return x >= P ? x - P : x;
}
// Returns x + P when x is negative and x unchanged otherwise.
// Only one addition is performed, so the result lands in [0, P)
// exactly when the argument is in (-P, P).
inline int reduce(int x) {
    return x < 0 ? x + P : x;
}
// In-place modular addition: x becomes x + y, with P subtracted once if the
// sum reached P or more.
inline void add(int &x, int y) {
    x += y;
    if (x >= P) {
        x -= P;
    }
}
// In-place modular subtraction: x becomes x - y, with P added once if the
// difference went negative.
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0) {
        x += P;
    }
}
// Computes b^p modulo P by square-and-multiply.
// Returns 1 when p == 0. The loop stops only once p has been shifted down to
// zero, so p is expected to be non-negative (the callers in this file pass
// P - 2 and n - i).
int mpow(int b, int p) {
    int result = 1;
    while (p != 0) {
        if (p & 1) {
            // Current bit is set: fold the current power of b into the result.
            result = (ll)result * b % P;
        }
        b = (ll)b * b % P;
        p >>= 1;
    }
    return result;
}
// fct[i] holds i! mod P and ifct[i] holds the modular inverse of i! for 0 <= i <= N;
// both are filled in by init().
int fct[N + 5], ifct[N + 5];
int binom(int n, int m) {
if (n < m || m < 0) return 0;
else return (ll)fct[n] * ifct[m] % P * ifct[n - m] % P;
}
// Fills fct[0..N] with factorials mod P and ifct[0..N] with their inverses.
// Only one modular exponentiation is needed: ifct[N] is obtained as
// fct[N]^(P-2) (Fermat inverse, P is prime), and the remaining inverses follow
// from 1/k! = (k+1) / (k+1)! walking downwards.
void init() {
    fct[0] = 1;
    for (int k = 0; k < N; ++k) {
        fct[k + 1] = (ll)fct[k] * (k + 1) % P;
    }

    ifct[N] = mpow(fct[N], P - 2);
    for (int k = N - 1; k >= 0; --k) {
        ifct[k] = (ll)ifct[k + 1] * (k + 1) % P;
    }
}
int n, m;
int main() {
init(), scanf("%d%d", &n, &m);
/*
static int f[N + 5]; f[0] = 1;
for (int i = 1; i <= m; i++) {
static int g[N + 5];
for (int i = 0; i <= n; i++)
g[i] = f[i], f[i] = reduce(0 - (ll)f[i] * i % P);
for (int i = 0; i <= n; i++) for (int j = i; j <= n; j++)
f[j] = (f[j] + (ll)binom(j + 2, i) * g[i]) % P;
}
int ans = 0;
for (int i = 0; i <= n; i++)
ans = (ans + (ll)ifct[n - i] * f[i]) % P;
printf("%lld\n", (ll)ans * fct[n] % P);
*/
// F_{m} = e^xF_{m - 1} + (2e^x - x - 2)F_{m - 1}' + (e^x - x - 1)F_{m - 1}''
// ans = n![x^n]F_{m}e^x
static int f[M + 5][M + 5]; f[0][0] = 1; // \sum f_{i, j}x^{i}e^{jx}
for (int k = 1; k <= m; k++) {
static int g[M + 5][M + 5];
for (int i = 0; i < k; i++) for (int j = 0; j < k; j++)
g[i][j] = f[i][j], f[i][j] = 0;
int del = 0;
for (int i = 0; i < k; i++) for (int j = 0; j < k; j++) {
add(f[i][j + 1], g[i][j]);
if (i) {
del = (ll)i * g[i][j] % P;
add(f[i - 1][j + 1], norm(del << 1));
sub(f[i - 1][j], norm(del << 1));
sub(f[i][j], del);
del = 2ll * i * j * g[i][j] % P;
add(f[i - 1][j + 1], del);
sub(f[i - 1][j], del);
sub(f[i][j], del);
if (i > 1) {
del = (ll)i * (i - 1) * g[i][j] % P;
add(f[i - 2][j + 1], del);
sub(f[i - 2][j], del);
sub(f[i - 1][j], del);
}
}
del = (ll)j * g[i][j] % P;
add(f[i][j + 1], norm(del << 1));
sub(f[i][j], norm(del << 1));
sub(f[i + 1][j], del);
del = (ll)j * j * g[i][j] % P;
add(f[i][j + 1], del);
sub(f[i][j], del);
sub(f[i + 1][j], del);
}
}
int ans = 0;
for (int i = 0; i <= m && i <= n; i++) for (int j = 0; j <= m; j++)
ans = (ans + (ll)ifct[n - i] * mpow(j + 1, n - i) % P * f[i][j]) % P;
printf("%lld\n", (ll)ans * fct[n] % P);
}