题解 P7092 【计数题】

· · 题解

本题的第一篇题解(出题人好像没给题解?那我来发一波)。。

注意到 \mu(\prod_{i\in S}i) 只有当每种质因数出现次数 \le 1 的时候才 \neq 0 ,因此 \prod_{i\in S}i 一定可以表示为 \prod_{i=1}^{k}p_i 的形式。

对于同一种 \prod_{i=1}^{k}p_i ,满足条件的 S 集合个数显然是 Bell_k (贝尔数)。可以这么考虑,相当于把 k 个不同的小球放入任意多个本质相同的组内,也就是第二类斯特林数 \sum\limits_{i=0}^{k} S(k,i)=Bell_k

其实只需要管到底取出了多少个质数,然后乘上贝尔数即为答案。这里可以通过分治NTT维护取了 k 个质数的 \varphi 之和。

特别地,

对于 k=0\mu 全为 0 ,故答案为 0

对于 k=1S 内部全是质数,那么不需要乘上贝尔数。

分治NTT部分有 \frac{n}{logn} 个质数,这里复杂度 O(\frac{n}{logn}log^2(\frac{n}{logn}))=O(nlogn)

计算贝尔数这里可以利用贝尔数的 EGF ,即 \sum\limits_{i}Bell_i\frac{x^i}{i!}=e^{e^x-1} ,多项式exp即可,复杂度 O(\frac{n}{logn}log(\frac{n}{logn}))=O(n)

总时间复杂度 O(nlogn) ,实测 n=10^6 跑的飞快。。

代码

#include <bits/stdc++.h>
using namespace std;

#define poly vector<int>

const int N = 1000005;
const int mod = 998244353;
const int G = 3;
const int Gi = 332748118;

inline int add(int a, int b) { return a + b >= mod ? a + b - mod : a + b; }
inline int sub(int a, int b) { return a - b < 0 ? a - b + mod : a - b; }
inline int qpow(int a, int b = mod - 2) {
  int res = 1;
  while (b > 0) {
    if (b & 1) res = 1ll * res * a % mod;
    a = 1ll * a * a % mod, b >>= 1;
  }
  return res;
}

namespace Poly {
  poly r, W;
  int lim, L;
  void getR(int n) {
    lim = 1, L = 0;
    while (lim <= n) lim <<= 1, L++;
    r.resize(lim), W.resize(lim);
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << L - 1);
  }
  void reverse(poly &a) {
    for (int i = 0, j = a.size() - 1; i < j; i++, j--) swap(a[i], a[j]);
  }
  void wf(poly &a) {
    int n = a.size();
    for (int i = 0; i < n - 1; i++) a[i] = 1ll * (i + 1) * a[i + 1] % mod;
    a[n - 1] = 0;
  }
  void jf(poly &a) {
    int n = a.size() - 1;
    for (int i = n - 1; i >= 1; i--) a[i] = 1ll * a[i - 1] * qpow(i) % mod;
    a[0] = 0;
  }
  void NTT(poly &a, int opt) {
    for (int i = 0; i < lim; i++) if (i < r[i]) swap(a[i], a[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
      int Wn = qpow(opt == 1 ? G : Gi, (mod - 1) / (mid << 1));
      W[0] = 1;
      for (int k = 1; k < mid; k++) W[k] = 1ll * W[k - 1] * Wn % mod;
      for (int j = 0; j < lim; j += mid << 1) {
        for (int k = 0; k < mid; k++) {
          int x = a[j + k], y = 1ll * a[j + k + mid] * W[k] % mod;
          a[j + k] = add(x, y);
          a[j + k + mid] = sub(x, y);
        }
      }
    } 
    if (opt == -1) {
      int linv = qpow(lim);
      for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * linv % mod;
    }
  }
  poly operator * (poly a, poly b) {
    int len = a.size() + b.size() - 1;
    getR(len), a.resize(lim), b.resize(lim);
    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1), a.resize(len);
    return a;
  }
  poly Inv(poly a) {
    if (a.size() == 1) return {qpow(a[0])};
    int n = a.size();
    poly ha = a; ha.resize(n + 1 >> 1);
    poly b = Inv(ha);
    getR(2 * n), a.resize(lim), b.resize(lim);
    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = 1ll * b[i] * (mod + 2 - 1ll * a[i] * b[i] % mod) % mod;
    NTT(b, -1), b.resize(n);
    return b;
  }
  poly Ln(poly a) {
    poly ta = a; wf(ta);
    int n = a.size();
    a = ta * Inv(a); jf(a);
    a.resize(n);
    return a;
  }
  poly Exp(poly a) {
    if (a.size() == 1) return {1};
    int n = a.size();
    poly ta = a; ta.resize(n + 1 >> 1);
    poly b = Exp(ta); b.resize(n);
    poly tb = Ln(b);
    for (int i = 0; i < n; i++) tb[i] = (mod + a[i] - tb[i]) % mod;
    tb[0]++;
    b = b * tb;
    b.resize(n);
    return b;
  }
}
using namespace Poly;

int vis[N], pr[N], len;
int fac[N], ifac[N];
void init(int n) {
  fac[0] = ifac[0] = 1;
  for (int i = 1; i <= n; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
  ifac[n] = qpow(fac[n]);
  for (int i = n - 1; i >= 1; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
  for (int i = 2; i <= n; i++) {
    if (!vis[i]) pr[len++] = i;
    for (int j = 0; j < len && pr[j] * i <= n; j++) {
      vis[pr[j] * i] = 1;
      if (i % pr[j] == 0) break;
    } 
  }
}

poly solve(int l, int r) {
  if (l == r) return {1, pr[l] - 1};
  int mid = l + r >> 1;
  poly L = solve(l, mid), R = solve(mid + 1, r);
  return L * R;
}

int main() {
  int n, k;
  cin >> n >> k;
  if (n <= 1 || k == 0) {
    puts("0");
    return 0;
  }
  init(n);
  vector<int> Bell(len + 1);
  for (int i = 1; i <= len; i++) Bell[i] = ifac[i];
  Bell = Exp(Bell);
  for (int i = 0; i <= len; i++) Bell[i] = 1ll * Bell[i] * fac[i] % mod;
  poly res = solve(0, len - 1);
  assert(res.size() <= len + 1);
  int ans = 0;
  if (k == 1) {
    for (int i = 1; i <= len; i++) {
      if (i & 1) {
        ans = (ans + mod - res[i]) % mod;
      } else {
        ans = (ans + res[i]) % mod;
      }
    }
  } else {
    for (int i = 1; i <= len; i++) {
      if (i & 1) {
        ans = (ans + mod - 1ll * res[i] * Bell[i] % mod) % mod;
      } else {
        ans = (ans + 1ll * res[i] * Bell[i] % mod) % mod; 
      }
    }
  }
  cout << ans << '\n';
}