题解 P7504 【「HMOI R1」可爱的德丽莎】
前置芝士:莫比乌斯反演、杜教筛
本题解省略部分分做法及其代码。
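(原文公式在提取时丢失,以下公式根据下方代码还原。)
设 $m$ 为 $k_1$ 或 $k_2$,代码中 `solve(n, m)` 计算的是 $\dfrac{1}{2}\left(1 + \sum_{i=1}^{n} [\gcd(i, m) = 1]\, i\,\varphi(i)\right)$,最终答案为 `solve(n, k1)` 与 `solve(n, k2)` 之积对 $998244353$ 取模。
设 $f(i) = [\gcd(i, m) = 1]\, i\,\varphi(i)$,$g(i) = [\gcd(i, m) = 1]\, i^2$,$h(i) = [\gcd(i, m) = 1]\, i$,并记 $S_f, S_g, S_h$ 为它们的前缀和。
到此为止,问题转变为了快速求出 $S_f(n) = \sum_{i=1}^{n} f(i)$。
想到杜教筛:由 $\sum_{d \mid n} \varphi(d) = n$ 且 $[\gcd(\cdot, m) = 1]$ 是完全积性函数,可得 $f * h = g$,于是 $S_f(n) = S_g(n) - \sum_{i=2}^{n} h(i)\, S_f\left(\left\lfloor \frac{n}{i} \right\rfloor\right)$。
显然 $S_g(n) = \sum_{d \mid m} \mu(d)\, d^2 \sum_{i=1}^{\lfloor n/d \rfloor} i^2$,$S_h(n) = \sum_{d \mid m} \mu(d)\, d \sum_{i=1}^{\lfloor n/d \rfloor} i$,枚举 $m$ 的约数即可求出。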
代码:
#include <iostream>
#include <map>
#include <cmath>
using namespace std;
// 64-bit integer type used for every modular product in this file.
typedef long long ll;
// N   : capacity of the sieve and prefix-sum tables; `limit` must stay below it.
//       NOTE(review): 1587401 looks like floor((2e9)^(2/3)) - confirm against the problem limits.
// M   : capacity of divisor[] / mu_list[].
//       NOTE(review): presumably the maximum divisor count of an allowed m - confirm.
// mod : prime modulus for all answers.
// inv2, inv6 : modular inverses of 2 and 6 under mod (2 * inv2 = 6 * inv6 = mod + 1).
const int N = 1587401 + 7, M = 1536 + 7, mod = 998244353, inv2 = 499122177, inv6 = 166374059;
// Sieve bound; main() sets it to n^(2/3). Prefix sums up to `limit` come from tables,
// larger arguments go through the memoised recursions.
int limit;
// prime[1..]   : primes found by the sieves (1-indexed).
// phi[i]       : Euler's totient, filled by init1().
// divisor[1..] : divisors of the current m, filled by init2().
// mu_list[k]   : Moebius value of divisor[k].
int prime[N], phi[N], divisor[M], mu_list[M];
// f, g, h are the per-index values built by init2(); *_sum are their prefix sums mod `mod`.
ll f[N], g[N], h[N], f_sum[N], g_sum[N], h_sum[N];
// p[i] == true means i is NOT prime (0 and 1 are flagged as well).
bool p[N];
// Memo tables for prefix sums above `limit`: mp1 for f, mp2 for g, mp3 for h.
map<int, ll> mp1, mp2, mp3;
inline void init1(){
int cnt = 0;
p[0] = p[1] = true;
phi[1] = 1;
for (register int i = 2; i <= limit; i++){
if (!p[i]){
prime[++cnt] = i;
phi[i] = i - 1;
}
for (register int j = 1; j <= cnt && i * prime[j] <= limit; j++){
int t = i * prime[j];
p[t] = true;
if (i % prime[j] == 0){
phi[t] = phi[i] * prime[j];
break;
}
phi[t] = phi[i] * (prime[j] - 1);
}
}
}
// Moebius function of n, computed by trial division.
// Returns 0 when n has a squared prime factor, otherwise (-1)^(number of prime factors).
// The trial divisor is 64-bit so that d * d cannot overflow when n is close to INT_MAX.
inline int mu(int n){
    int sign = 1;
    for (long long d = 2; d * d <= n; d++){
        if (n % d == 0){
            n /= d;
            // A second factor d means d^2 divided the original n.
            if (n % d == 0) return 0;
            sign = -sign;
        }
    }
    // Whatever is left above 1 is a single prime factor larger than sqrt(n).
    if (n > 1) sign = -sign;
    return sign;
}
inline int init2(int n){
int prime_cnt = 0, divisor_cnt = 0, t = sqrt(n);
p[0] = p[1] = true;
f[1] = 1;
g[1] = 1;
h[1] = 1;
for (register int i = 2; i <= limit; i++){
if (!p[i]){
prime[++prime_cnt] = i;
f[i] = n % i == 0 ? 0 : (ll)i * phi[i] % mod;
g[i] = n % i == 0 ? 0 : (ll)i * i % mod;
h[i] = n % i == 0 ? 0 : i;
}
for (register int j = 1; j <= prime_cnt && i * prime[j] <= limit; j++){
int t = i * prime[j];
p[t] = true;
f[t] = f[i] != 0 && n % prime[j] != 0 ? (ll)t * phi[t] % mod : 0;
g[t] = g[i] * g[prime[j]] % mod;
h[t] = h[i] * h[prime[j]] % mod;
if (i % prime[j] == 0) break;
}
}
for (register int i = 1; i < N; i++){
f_sum[i] = (f_sum[i - 1] + f[i]) % mod;
g_sum[i] = (g_sum[i - 1] + g[i]) % mod;
h_sum[i] = (h_sum[i - 1] + h[i]) % mod;
}
for (register int i = 1; i <= t; i++){
if (n % i == 0){
divisor_cnt++;
divisor[divisor_cnt] = i;
mu_list[divisor_cnt] = mu(i);
if (i * i != n){
int tn = n / i;
divisor_cnt++;
divisor[divisor_cnt] = tn;
mu_list[divisor_cnt] = mu(tn);
}
}
}
return divisor_cnt;
}
// Sum of squares 1^2 + 2^2 + ... + n^2 modulo `mod`,
// via n (n + 1) (2n + 1) / 6 with the division done by multiplying with inv6.
inline ll sum2(int n){
    const ll consecutive = static_cast<ll>(n) * (n + 1) % mod;
    const ll odd_term = static_cast<ll>(n) * 2 % mod + 1;
    const ll numerator = consecutive * odd_term % mod;
    return numerator * inv6 % mod;
}
// Prefix sum g(1) + ... + g(n) modulo `mod`.
// n   : upper bound of the sum.
// cnt : number of valid entries in divisor[] / mu_list[] (as returned by init2()).
// Values up to `limit` come from the g_sum table. Larger ones are computed as
//   sum over listed divisors d of mu(d) * d^2 * sum2(n / d)
// and memoised in mp2.
inline ll get_g_sum(int n, int cnt){
    if (n <= limit) return g_sum[n];
    const map<int, ll>::iterator cached = mp2.find(n);
    if (cached != mp2.end()) return cached->second;
    ll ans = 0;
    for (int i = 1; i <= cnt; i++){
        // Divisors with a squared factor contribute nothing.
        if (mu_list[i] == 0) continue;
        const ll weight = (ll)mu_list[i] * divisor[i] % mod * divisor[i] % mod;
        // weight may be negative; the trailing "+ mod" brings the sum back into [0, mod).
        ans = ((ans + weight * sum2(n / divisor[i]) % mod) % mod + mod) % mod;
    }
    return mp2[n] = ans;
}
// Sum 1 + 2 + ... + n modulo `mod`.
// The halving is exact because n (n + 1) is always even, so no modular inverse is needed.
inline ll sum1(int n){
    const ll doubled_sum = static_cast<ll>(n) * (n + 1);
    const ll triangular = doubled_sum / 2;
    return triangular % mod;
}
// Prefix sum h(1) + ... + h(n) modulo `mod`.
// n   : upper bound of the sum.
// cnt : number of valid entries in divisor[] / mu_list[] (as returned by init2()).
// Values up to `limit` come from the h_sum table. Larger ones are computed as
//   sum over listed divisors d of mu(d) * d * sum1(n / d)
// and memoised in mp3.
inline ll get_h_sum(int n, int cnt){
    if (n <= limit) return h_sum[n];
    const map<int, ll>::iterator cached = mp3.find(n);
    if (cached != mp3.end()) return cached->second;
    ll ans = 0;
    for (int i = 1; i <= cnt; i++){
        // Divisors with a squared factor contribute nothing.
        if (mu_list[i] == 0) continue;
        const ll weight = (ll)mu_list[i] * divisor[i] % mod;
        // weight may be negative; the trailing "+ mod" brings the sum back into [0, mod).
        ans = ((ans + weight * sum1(n / divisor[i]) % mod) % mod + mod) % mod;
    }
    return mp3[n] = ans;
}
// Prefix sum f(1) + ... + f(n) modulo `mod`.
// n   : upper bound of the sum.
// cnt : divisor count forwarded to get_g_sum() / get_h_sum().
// Values up to `limit` come from the f_sum table. Larger ones use the recurrence
//   F(n) = G(n) - sum_{i = 2..n} h(i) * F(n / i)
// evaluated over blocks [i, j] on which n / i is constant, and are memoised in mp1.
inline ll get_f_sum(int n, int cnt){
    if (n <= limit) return f_sum[n];
    const map<int, ll>::iterator cached = mp1.find(n);
    if (cached != mp1.end()) return cached->second;
    ll ans = get_g_sum(n, cnt);
    // h prefix at the end of the previous block; the first block starts at i = 2.
    ll h_before = get_h_sum(1, cnt);
    for (int i = 2, j; i <= n; i = j + 1){
        const int quotient = n / i;
        j = n / quotient;  // last index sharing the same quotient
        const ll h_upto = get_h_sum(j, cnt);
        const ll h_block = h_upto - h_before;  // may be negative, fixed up below
        ans = ((ans - get_f_sum(quotient, cnt) * h_block % mod) % mod + mod) % mod;
        h_before = h_upto;
    }
    return mp1[n] = ans;
}
// Answer for one parameter m: (F(n) + 1) / 2 modulo `mod`, where F(n) is the
// value returned by get_f_sum() after the tables have been rebuilt for m.
inline ll solve(int n, int m){
    // Memoised prefix sums belong to the previous m, so drop them first.
    mp3.clear();
    mp2.clear();
    mp1.clear();
    const int divisor_total = init2(m);
    const ll f_prefix = get_f_sum(n, divisor_total);
    const ll doubled_answer = f_prefix + 1;
    return doubled_answer * inv2 % mod;
}
int main(){
int n, k1, k2;
cin >> n >> k1 >> k2;
limit = pow(n, 2.0 / 3.0);
init1();
cout << solve(n, k1) * solve(n, k2) % mod;
return 0;
}