P9135 [THUPC 2023 初赛] 快速 LCM 变换

题解

P9135 [THUPC 2023 初赛] 快速 LCM 变换

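对质数 $p$,我们单独分析贡献。设 $r_i$ 含有 $k_1$ 个 $p$(即 $k_1 = v_p(r_i)$),$r_j$ 含有 $k_2$ 个 $p$,$r_i + r_j$ 含有 $k_3$ 个 $p$,且 $p$ 在所有 $r_i$ 中的最高次数为 $mx_p$,第二大次数为 $smx_p$。不妨设 $k_1 \geq k_2$,此时若 $k_1 > k_2$ 则 $k_3 = k_2$,若 $k_1 = k_2$ 则 $k_3 \geq k_1$。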

考虑设 $c_i$ 表示若去掉 $r_i$,则 $\operatorname{lcm}$ 会变成原来的多少倍,即对每个质数 $p$,若 $v_p(r_i) = mx_p$,则 $c_i$ 乘上 $\frac 1 {p ^ {mx_p - smx_p}}$。计算 $(i, j)$ 的贡献时,先按去掉 $r_i$ 和 $r_j$ 的影响(情况一)当成 $c_ic_j$ 算,再加入 $r_i + r_j$ 带来的贡献(情况二):若 $v_p(r_i + r_j) > mx_p$,则 $\operatorname{lcm}$ 乘上 $p ^ {v_p(r_i + r_j) - mx_p}$。

对于每个 $k$,计算所有满足 $r_i + r_j = k$ 的 $c_ic_j$ 之和,这是标准卷积形式,NTT 即可。

时间复杂度 $\mathcal{O}(V\log V)$,其中 $V$ 是值域。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ull = unsigned long long;

// Sentinel placed before the global tables; main() prints the distance
// between this object and the matching `Med` sentinel as a memory estimate.
bool Mbe;
// Capacity of every static table and the transform length passed to NTT.
constexpr int N = 1 << 21;
// Prime modulus used for all arithmetic in this file.
constexpr int mod = 998244353;
// Modular inverse of 3: 3 * ((mod + 1) / 3) == mod + 1, which is 1 modulo mod.
constexpr int iv3 = (mod + 1) / 3;

// In-place modular addition: x <- (x + y) mod `mod`.
// A single conditional subtraction is used, so both operands are expected to
// already lie in [0, mod).
void add(int &x, int y) {
  x += y;
  if (x >= mod) x -= mod;
}

// Binary exponentiation: returns a^b modulo `mod` for a non-negative
// exponent b (ksm(a, 0) == 1).
int ksm(int a, int b) {
  int result = 1;
  for (; b; b >>= 1) {
    // Fold the current power of a into the result when this bit of b is set.
    if (b & 1) result = 1ll * result * a % mod;
    a = 1ll * a * a % mod;
  }
  return result;
}

int rev[N];
void NTT(int L, int *a, int type) {
  static ull f[N], w[N];
  for(int i = 0; i < L; i++) {
    rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? L >> 1 : 0);
    f[i] = a[rev[i]];
  }
  for(int d = 1; d < L; d <<= 1) {
    int wn = ksm(type ? 3 : iv3, (mod - 1) / (d + d));
    for(int i = w[0] = 1; i < d; i++) w[i] = w[i - 1] * wn % mod;
    for(int i = 0; i < L; i += d + d)
      for(int j = 0; j < d; j++) {
        int y = f[i | j | d] * w[j] % mod;
        f[i | j | d] = f[i | j] + mod - y, f[i | j] += y; 
      }
    if(d == (1 << 16)) for(int i = 0; i < L; i++) f[i] %= mod;
  }
  int inv = ksm(L, mod - 2);
  for(int i = 0; i < L; i++) a[i] = f[i] % mod * (type ? 1 : inv) % mod;
}

// Linear-sieve state filled by main(): vis[x] is set for composite x,
// pr[1..cnt] lists the primes found, and mpr[x] is the prime the sieve
// recorded for x (used to peel prime factors off x one at a time).
int vis[N];
int pr[N];
int mpr[N];
int cnt;
// Input size and values r[1..n].
int n;
int r[N];
// coe[i]: multiplicative coefficient attached to r[i] (starts at 1).
int coe[N];
// f[v]: sum of coe[i] over all i with r[i] == v; later overwritten with the
// self-convolution of that array.
int f[N];
// For a prime p, mx[p] / smx[p] hold the largest and second-largest
// (exponent of p in r[i], i) pairs seen over the input.
pair<int, int> mx[N];
pair<int, int> smx[N];
// buc[x]: prime factorisation of x as (prime, exponent) pairs.
vector<pair<int, int>> buc[N];
// Sentinel placed after the global tables; paired with `Mbe` in main() to
// report the size of the static data.
bool Med;
int main() {
  fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
  #ifdef ALEX_WEI
    FILE* IN = freopen("1.in", "r", stdin); 
    FILE* OUT = freopen("1.out", "w", stdout);
  #endif
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  for(int i = 2; i < N; i++) {
    if(!vis[i]) pr[++cnt] = mpr[i] = i;
    for(int j = 1; j <= cnt && i * pr[j] < N; j++) {
      vis[i * pr[j]] = 1, mpr[i * pr[j]] = pr[j];
      if(i % pr[j] == 0) break;
    }
    int tmp = i;
    while(tmp > 1) {
      int p = mpr[tmp], c = 0;
      while(tmp % p == 0) tmp /= p, c++;
      buc[i].push_back({p, c});
    }
  }
  cin >> n;
  for(int i = 1; i <= n; i++) {
    cin >> r[i], coe[i] = 1;
    for(auto _ : buc[r[i]]) {
      int p = _.first, c = _.second;
      pair<int, int> v = {c, i};
      if(v > mx[p]) smx[p] = mx[p], mx[p] = v;
      else if(v > smx[p]) smx[p] = v;
    }
  }
  int lcm = 1;
  for(int i = 2; i < N; i++) {
    if(!mx[i].first) continue;
    lcm = 1ll * lcm * ksm(i, mx[i].first) % mod;
    int v = ksm(i, mx[i].first - smx[i].first);
    coe[mx[i].second] = 1ll * coe[mx[i].second] * ksm(v, mod - 2) % mod;
  }
  for(int i = 1; i <= n; i++) add(f[r[i]], coe[i]);
  NTT(N, f, 1);
  for(int i = 0; i < N; i++) f[i] = 1ll * f[i] * f[i] % mod;
  NTT(N, f, 0);
  for(int i = 1; i <= n; i++) add(f[r[i] + r[i]], mod - 1ll * coe[i] * coe[i] % mod);
  int ans = 0;
  for(int i = 1; i < N; i++) {
    if(!f[i]) continue;
    int other = 1;
    for(auto _ : buc[i]) {
      int p = _.first, c = _.second;
      other = 1ll * other * ksm(p, max(0, c - mx[p].first)) % mod;
    }
    add(ans, 1ll * lcm * other % mod * f[i] % mod);
  }
  cout << 1ll * ans * (mod + 1 >> 1) % mod << "\n";
  cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n";
  return 0;
}