题解 P7092 【计数题】
wlzhouzhuan · · 题解
本题的第一篇题解(出题人好像没给题解?那我来发一波)。
注意到
对于同一种
其实只需要管到底取出了多少个质数,然后乘上贝尔数即为答案。这里可以通过分治NTT维护取了 $i$ 个质数时的贡献之和。
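特别地,
对于 $k = 0$(以及 $n \le 1$),答案为 $0$;
对于 $k = 1$,不需要乘贝尔数,答案为 $\sum_{i \ge 1} (-1)^i f_i$,其中 $f_i$ 的定义见下。
分治NTT部分有 $F(x) = \sum_{i} f_i x^i = \prod_{p \le n,\ p\ \text{为质数}} \bigl(1 + (p-1)x\bigr)$,一般情况($k \ge 2$)的答案即为 $\sum_{i \ge 1} (-1)^i f_i B_i$,其中 $B_i$ 为第 $i$ 个贝尔数。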
计算贝尔数这里可以利用贝尔数的指数型生成函数 $\sum_{i \ge 0} B_i \dfrac{x^i}{i!} = \exp\left(e^x - 1\right)$,做一次多项式 Exp 后再逐项乘上 $i!$ 即可。
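总时间复杂度 $O\!\left(n + \pi(n)\log^2 n\right)$,其中 $\pi(n)$ 为不超过 $n$ 的质数个数;由于 $\pi(n) \approx n / \ln n$,即约为 $O(n \log n)$。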
代码
#include <bits/stdc++.h>
using namespace std;
#define poly vector<int>
// Capacity of the global lookup tables declared further down in the file.
const int N = 1000005;
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root of `mod`, and its modular inverse (3 * 332748118 == mod + 1).
const int G = 3;
const int Gi = 332748118;

// Modular sum of two residues; both arguments must already lie in [0, mod).
inline int add(int a, int b) {
  int total = a + b;
  if (total >= mod) total -= mod;
  return total;
}

// Modular difference a - b of two residues; both must already lie in [0, mod).
inline int sub(int a, int b) {
  int diff = a - b;
  if (diff < 0) diff += mod;
  return diff;
}

// Binary exponentiation: returns a^b modulo `mod` (1 when b <= 0).
// With the default exponent mod - 2 this is the modular inverse of a
// (Fermat's little theorem), valid whenever a is not a multiple of mod.
inline int qpow(int a, int b = mod - 2) {
  long long base = a;
  long long acc = 1;
  for (; b > 0; b >>= 1) {
    if (b & 1) acc = acc * base % mod;
    base = base * base % mod;
  }
  return static_cast<int>(acc);
}
// Polynomial arithmetic modulo `mod`, built on an iterative radix-2 NTT.
//
// A polynomial is a coefficient vector: index i holds the coefficient of x^i,
// and every entry is expected to lie in [0, mod).
//
// The transform size and its lookup tables are shared namespace-level state:
// getR() sets them up and NTT() consumes them, so getR() must be called before
// each batch of transforms of a given size.
namespace Poly {
// r: bit-reversal permutation for the current transform size.
// W: scratch buffer holding the twiddle factors of one butterfly layer.
poly r, W;
// lim: current transform size (a power of two); L: log2(lim).
int lim, L;

// Picks the smallest power of two strictly greater than n (at least 1) as the
// transform size, stores it in `lim` / `L`, and rebuilds r[0..lim).
void getR(int n) {
  lim = 1, L = 0;
  while (lim <= n) lim <<= 1, L++;
  r.resize(lim), W.resize(lim);
  r[0] = 0;
  // r[i] is r[i / 2] shifted right once with i's lowest bit moved to the top.
  // Starting at i = 1 keeps r[0] == 0 and avoids a shift by -1 when lim == 1.
  for (int i = 1; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

// Reverses the coefficient order of `a` in place.
void reverse(poly &a) {
  std::reverse(a.begin(), a.end());
}

// In-place formal derivative. The vector keeps its length; the top slot
// becomes 0. An empty vector is left untouched.
void wf(poly &a) {
  if (a.empty()) return;
  int n = a.size();
  for (int i = 0; i + 1 < n; i++) a[i] = 1ll * (i + 1) * a[i + 1] % mod;
  a[n - 1] = 0;
}

// In-place formal integral with constant term 0. The vector keeps its length,
// so the original top coefficient is dropped. An empty vector is left untouched.
void jf(poly &a) {
  if (a.empty()) return;
  // Walk downwards so a[i - 1] is still the un-integrated value when read.
  // The loop must start at the last index, otherwise the highest coefficient
  // would be left un-integrated.
  for (int i = static_cast<int>(a.size()) - 1; i >= 1; i--) {
    a[i] = 1ll * a[i - 1] * qpow(i) % mod;
  }
  a[0] = 0;
}

// In-place number-theoretic transform of the first `lim` entries of `a`.
// opt == 1 runs the forward transform; opt == -1 runs the inverse transform
// and divides by lim. Requires a.size() >= lim and a prior getR() call.
void NTT(poly &a, int opt) {
  for (int i = 0; i < lim; i++) if (i < r[i]) std::swap(a[i], a[r[i]]);
  for (int mid = 1; mid < lim; mid <<= 1) {
    // Primitive (2 * mid)-th root of unity for this layer.
    int Wn = qpow(opt == 1 ? G : Gi, (mod - 1) / (mid << 1));
    W[0] = 1;
    for (int k = 1; k < mid; k++) W[k] = 1ll * W[k - 1] * Wn % mod;
    for (int j = 0; j < lim; j += mid << 1) {
      for (int k = 0; k < mid; k++) {
        int x = a[j + k], y = 1ll * a[j + k + mid] * W[k] % mod;
        a[j + k] = add(x, y);
        a[j + k + mid] = sub(x, y);
      }
    }
  }
  if (opt == -1) {
    int linv = qpow(lim);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * linv % mod;
  }
}

// Product of two polynomials; the result has a.size() + b.size() - 1
// coefficients. If either operand is empty the product is empty.
poly operator * (poly a, poly b) {
  if (a.empty() || b.empty()) return {};
  int len = a.size() + b.size() - 1;
  getR(len), a.resize(lim), b.resize(lim);
  NTT(a, 1), NTT(b, 1);
  for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
  NTT(a, -1), a.resize(len);
  return a;
}

// Multiplicative inverse of `a` modulo x^{a.size()} by Newton iteration.
// Precondition: `a` is non-empty and a[0] is invertible modulo `mod`.
poly Inv(poly a) {
  if (a.size() == 1) return {qpow(a[0])};
  int n = a.size();
  // Solve modulo x^{ceil(n / 2)} first, then lift: b <- b * (2 - a * b).
  poly ha = a; ha.resize((n + 1) >> 1);
  poly b = Inv(ha);
  getR(2 * n), a.resize(lim), b.resize(lim);
  NTT(a, 1), NTT(b, 1);
  for (int i = 0; i < lim; i++) b[i] = 1ll * b[i] * (mod + 2 - 1ll * a[i] * b[i] % mod) % mod;
  NTT(b, -1), b.resize(n);
  return b;
}

// Formal logarithm of `a` modulo x^{a.size()}, computed as the integral of
// a' / a. Precondition: `a` is non-empty with a[0] == 1.
poly Ln(poly a) {
  int n = a.size();
  poly da = a; wf(da);
  poly q = da * Inv(a);
  // Only the first n coefficients of a' / a matter, so truncate before
  // integrating instead of integrating the full 2n - 1 terms.
  q.resize(n);
  jf(q);
  return q;
}

// Formal exponential of `a` modulo x^{a.size()} by Newton iteration.
// Precondition: `a` is non-empty with a[0] == 0.
poly Exp(poly a) {
  if (a.size() == 1) return {1};
  int n = a.size();
  poly ta = a; ta.resize((n + 1) >> 1);
  poly b = Exp(ta); b.resize(n);
  // Newton step: b <- b * (1 + a - ln b), truncated to n terms.
  poly tb = Ln(b);
  for (int i = 0; i < n; i++) tb[i] = (mod + a[i] - tb[i]) % mod;
  tb[0] = add(tb[0], 1);
  b = b * tb;
  b.resize(n);
  return b;
}
}
// Make the polynomial helpers (operator*, Exp, ...) visible at global scope.
using namespace Poly;
// Sieve state filled by init(): vis[x] != 0 marks x as composite, and
// pr[0..len) holds the primes found so far in increasing order.
int vis[N], pr[N], len;
// fac[i] = i! modulo `mod`; ifac[i] is the modular inverse of fac[i].
// Both are filled by init().
int fac[N], ifac[N];
// Fills the global tables for arguments up to n:
//   * fac[i] / ifac[i]: factorials and inverse factorials modulo `mod`;
//   * pr[0..len): all primes <= n, via a linear sieve that marks composites
//     in vis[].
// Appends to pr[] starting at the current value of len.
void init(int n) {
  fac[0] = 1;
  for (int i = 1; i <= n; ++i) {
    fac[i] = static_cast<long long>(fac[i - 1]) * i % mod;
  }

  // One modular inversion at the top, then walk down with
  // 1 / (i - 1)! = i * (1 / i!).
  ifac[0] = 1;
  ifac[n] = qpow(fac[n]);
  for (int i = n; i > 1; --i) {
    ifac[i - 1] = static_cast<long long>(ifac[i]) * i % mod;
  }

  for (int x = 2; x <= n; ++x) {
    if (vis[x] == 0) pr[len++] = x;
    for (int j = 0; j < len; ++j) {
      const int multiple = pr[j] * x;
      if (multiple > n) break;
      vis[multiple] = 1;
      // Stop at the smallest prime factor of x so each composite is marked
      // exactly once.
      if (x % pr[j] == 0) break;
    }
  }
}
// Divide-and-conquer product of the linear factors (1 + (pr[i] - 1) * x) for
// i in [l, r]. Requires l <= r.
poly solve(int l, int r) {
  if (l == r) {
    poly leaf(2);
    leaf[0] = 1;
    leaf[1] = pr[l] - 1;
    return leaf;
  }
  const int mid = (l + r) >> 1;
  poly lhs = solve(l, mid);
  poly rhs = solve(mid + 1, r);
  return lhs * rhs;
}
int main() {
int n, k;
cin >> n >> k;
if (n <= 1 || k == 0) {
puts("0");
return 0;
}
init(n);
vector<int> Bell(len + 1);
for (int i = 1; i <= len; i++) Bell[i] = ifac[i];
Bell = Exp(Bell);
for (int i = 0; i <= len; i++) Bell[i] = 1ll * Bell[i] * fac[i] % mod;
poly res = solve(0, len - 1);
assert(res.size() <= len + 1);
int ans = 0;
if (k == 1) {
for (int i = 1; i <= len; i++) {
if (i & 1) {
ans = (ans + mod - res[i]) % mod;
} else {
ans = (ans + res[i]) % mod;
}
}
} else {
for (int i = 1; i <= len; i++) {
if (i & 1) {
ans = (ans + mod - 1ll * res[i] * Bell[i] % mod) % mod;
} else {
ans = (ans + 1ll * res[i] * Bell[i] % mod) % mod;
}
}
}
cout << ans << '\n';
}