题解 P7504 【「HMOI R1」可爱的德丽莎】

· · 题解

前置芝士:莫比乌斯反演、杜教筛

本题解省略部分分做法及其代码。

设 solve(n, m) = \displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] \sum_{j = 1}^i [\gcd(i, j) = 1] j,则原式 = solve(n, k_1) solve(n, k_2)。

solve(n, m) = \displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] \sum_{j = 1}^i j \sum_{d \mid \gcd(i, j)} \mu(d) = \displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] \sum_{d \mid i} \mu(d) \sum_{d \mid j}^i j

设 S_1(n) = \displaystyle\sum_{i = 1}^n i,则:

solve(n, m) = \displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] \sum_{d \mid i} \mu(d) d S_1(\frac{i}{d}) = \frac{\displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] i \sum_{d \mid i} \mu(d) (\frac{i}{d} + 1)}{2} = \frac{\displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] i (\varphi(i) + \epsilon(i))}{2} = \frac{\displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] i \varphi(i) + 1}{2}

到此为止,问题转变为了快速求出 f(n) = [\gcd(n, m) = 1] n \varphi(n) 的前缀和。

想到 id_2 = id * (id \cdot \varphi)。记 \epsilon(\gcd)(n) = [\gcd(n, m) = 1],则 \epsilon(\gcd) 为完全积性函数,发现 \epsilon(\gcd) \cdot id_2 = (\epsilon(\gcd) \cdot id) * (\epsilon(\gcd) \cdot id \cdot \varphi),考虑杜教筛。

显然 \displaystyle\sum_{i = 1}^n [\gcd(i, m) = 1] i = \displaystyle\sum_{d \mid m} \mu(d) d S_1(\lfloor \frac{n}{d} \rfloor),\epsilon(\gcd) \cdot id_2 的前缀和同理(把 d S_1 换成 d^2 S_2,其中 S_2(n) = \displaystyle\sum_{i = 1}^n i^2)。至此可以杜教筛。时间复杂度为 O(\sqrt{n} (\tau(k_1) + \tau(k_2)) + n^{\frac{2}{3}})。

代码:

#include <iostream>
#include <map>
#include <cmath>

using namespace std;

typedef long long ll;

// Table capacities: N sizes every sieve-indexed array, M sizes the divisor tables.
const int N = 1587401 + 7;
const int M = 1536 + 7;
// Modulus for all arithmetic, with 2 * inv2 == 1 and 6 * inv6 == 1 (mod mod).
const int mod = 998244353;
const int inv2 = 499122177;
const int inv6 = 166374059;

// Sieve bound; main() sets it to n^(2/3) before any table is built.
int limit;

// prime[1..]: primes found by the sieve; phi[i]: Euler's totient (filled by init1).
int prime[N];
int phi[N];
// divisor[1..cnt] and mu_list[1..cnt]: divisors of the current modulus argument
// and their Mobius values (filled by init2).
int divisor[M];
int mu_list[M];

// Per-index values built by init2, and their prefix sums modulo mod.
ll f[N];
ll g[N];
ll h[N];
ll f_sum[N];
ll g_sum[N];
ll h_sum[N];

// p[i] is true when i is not prime (0 and 1 included).
bool p[N];

// Memo tables for arguments above limit: mp1 -> get_f_sum, mp2 -> get_g_sum,
// mp3 -> get_h_sum.
map<int, ll> mp1;
map<int, ll> mp2;
map<int, ll> mp3;

// Linear sieve over [2, limit]: marks non-primes in p[], stores the primes in
// prime[1..cnt], and fills phi[] with Euler's totient.
// The `register` specifier was dropped: it was removed from the language in
// C++17 and is rejected by conforming compilers.
inline void init1(){
    int cnt = 0;
    p[0] = p[1] = true;
    phi[1] = 1;
    for (int i = 2; i <= limit; i++){
        if (!p[i]){
            prime[++cnt] = i;
            phi[i] = i - 1;
        }
        for (int j = 1; j <= cnt && i * prime[j] <= limit; j++){
            const int t = i * prime[j];
            p[t] = true;
            if (i % prime[j] == 0){
                // prime[j] already divides i, so the totient just scales by prime[j];
                // stopping here keeps each composite visited exactly once.
                phi[t] = phi[i] * prime[j];
                break;
            }
            // prime[j] is a new prime factor of t: multiply by (prime[j] - 1).
            phi[t] = phi[i] * (prime[j] - 1);
        }
    }
}

// Mobius function of n by trial division: returns 0 when a squared prime
// divides n, otherwise -1 raised to the number of distinct prime factors.
inline int mu(int n){
    int ans = 1;
    // The bound is evaluated in 64 bits: i * i no longer fits in int once i
    // exceeds 46340, which the loop reaches when n is a prime close to 2^31.
    for (int i = 2; static_cast<long long>(i) * i <= n; i++){
        if (n % i == 0){
            n /= i;
            // i still divides the quotient, so i * i divided the original n.
            if (n % i == 0) return 0;
            ans = -ans;
        }
    }
    // Whatever is left above 1 is a single remaining prime factor.
    if (n > 1) ans = -ans;
    return ans;
}

inline int init2(int n){
    int prime_cnt = 0, divisor_cnt = 0, t = sqrt(n);
    p[0] = p[1] = true;
    f[1] = 1;
    g[1] = 1;
    h[1] = 1;
    for (register int i = 2; i <= limit; i++){
        if (!p[i]){
            prime[++prime_cnt] = i;
            f[i] = n % i == 0 ? 0 : (ll)i * phi[i] % mod;
            g[i] = n % i == 0 ? 0 : (ll)i * i % mod;
            h[i] = n % i == 0 ? 0 : i;
        }
        for (register int j = 1; j <= prime_cnt && i * prime[j] <= limit; j++){
            int t = i * prime[j];
            p[t] = true;
            f[t] = f[i] != 0 && n % prime[j] != 0 ? (ll)t * phi[t] % mod : 0;
            g[t] = g[i] * g[prime[j]] % mod;
            h[t] = h[i] * h[prime[j]] % mod;
            if (i % prime[j] == 0) break;
        }
    }
    for (register int i = 1; i < N; i++){
        f_sum[i] = (f_sum[i - 1] + f[i]) % mod;
        g_sum[i] = (g_sum[i - 1] + g[i]) % mod;
        h_sum[i] = (h_sum[i - 1] + h[i]) % mod;
    }
    for (register int i = 1; i <= t; i++){
        if (n % i == 0){
            divisor_cnt++;
            divisor[divisor_cnt] = i;
            mu_list[divisor_cnt] = mu(i);
            if (i * i != n){
                int tn = n / i;
                divisor_cnt++;
                divisor[divisor_cnt] = tn;
                mu_list[divisor_cnt] = mu(tn);
            }
        }
    }
    return divisor_cnt;
}

// 1^2 + 2^2 + ... + n^2 modulo mod, via n(n+1)(2n+1)/6 with the division
// carried out as a multiplication by inv6.
inline ll sum2(int n){
    const ll pair_product = static_cast<ll>(n) * (n + 1) % mod;
    const ll odd_factor = static_cast<ll>(n) * 2 % mod + 1;
    return pair_product * odd_factor % mod * inv6 % mod;
}

// Prefix sum of g (i^2 restricted to i coprime to the current modulus) up to n.
// Small n comes from the sieve table; larger n is computed by summing
// mu(d) * d^2 * sum2(n / d) over the cnt stored divisors d, and memoised in mp2.
inline ll get_g_sum(int n, int cnt){
    if (n <= limit) return g_sum[n];
    // One lookup instead of count() followed by operator[].
    map<int, ll>::iterator cached = mp2.find(n);
    if (cached != mp2.end()) return cached->second;
    ll ans = 0;
    for (int i = 1; i <= cnt; i++){
        const ll term = (ll)mu_list[i] * divisor[i] % mod * divisor[i] % mod * sum2(n / divisor[i]) % mod;
        // term may be negative when mu is -1; fold it back into [0, mod).
        ans = ((ans + term) % mod + mod) % mod;
    }
    return mp2[n] = ans;
}

// 1 + 2 + ... + n modulo mod; the exact 64-bit product is halved before reducing.
inline ll sum1(int n){
    const ll doubled = static_cast<ll>(n) * (n + 1);
    return doubled / 2 % mod;
}

// Prefix sum of h (i restricted to i coprime to the current modulus) up to n.
// Small n comes from the sieve table; larger n is computed by summing
// mu(d) * d * sum1(n / d) over the cnt stored divisors d, and memoised in mp3.
inline ll get_h_sum(int n, int cnt){
    if (n <= limit) return h_sum[n];
    // One lookup instead of count() followed by operator[].
    map<int, ll>::iterator cached = mp3.find(n);
    if (cached != mp3.end()) return cached->second;
    ll ans = 0;
    for (int i = 1; i <= cnt; i++){
        // Widened to 64 bits before multiplying, matching get_g_sum.
        const ll coef = (ll)mu_list[i] * divisor[i] % mod;
        // The term may be negative when mu is -1; fold it back into [0, mod).
        ans = ((ans + coef * sum1(n / divisor[i]) % mod) % mod + mod) % mod;
    }
    return mp3[n] = ans;
}

// Prefix sum of f (i * phi(i) restricted to i coprime to the current modulus)
// up to n. Small n comes from the sieve table; larger n uses the recurrence
//   F(n) = G(n) - sum_{i >= 2} h(i) * F(n / i),
// where G = get_g_sum and the h-weights come from get_h_sum. Indices i sharing
// the same quotient n / i are handled as one block. Results are memoised in mp1.
inline ll get_f_sum(int n, int cnt){
    if (n <= limit) return f_sum[n];
    // One lookup instead of count() followed by operator[].
    map<int, ll>::iterator cached = mp1.find(n);
    if (cached != mp1.end()) return cached->second;
    ll ans = get_g_sum(n, cnt);
    // Prefix of h just before the current block; each block's right end
    // becomes the next block's left boundary, so it is carried over rather
    // than queried a second time.
    ll h_before = get_h_sum(1, cnt);
    for (int i = 2, j; i <= n; i = j + 1){
        const int tn = n / i;
        j = n / tn;  // last index whose quotient is still tn
        const ll h_upto = get_h_sum(j, cnt);
        ans = ((ans - get_f_sum(tn, cnt) * (h_upto - h_before) % mod) % mod + mod) % mod;
        h_before = h_upto;
    }
    return mp1[n] = ans;
}

// Evaluates (F(n) + 1) / 2 modulo mod, where F is the prefix sum returned by
// get_f_sum for the modulus m.
inline ll solve(int n, int m){
    // The memo tables are keyed by the argument only, so values cached for a
    // previous modulus must be dropped.
    mp3.clear();
    mp2.clear();
    mp1.clear();
    const int divisor_total = init2(m);
    const ll prefix = get_f_sum(n, divisor_total);
    return (prefix + 1) * inv2 % mod;
}

int main(){
    int n, k1, k2;
    cin >> n >> k1 >> k2;
    limit = pow(n, 2.0 / 3.0);
    init1();
    cout << solve(n, k1) * solve(n, k2) % mod;
    return 0;
}