CF439E Devu and Birthday Celebration 题解

做法：阶乘递推预处理（DP）+ 隔板法组合计数 + 暴力容斥（在 n 的无平方因子约数上按莫比乌斯函数求和）

#include <iostream>
#define int long long
using namespace std;
// Capacity of every lookup table. The precomputation routines only fill
// indices up to 100000, so most of this capacity is unused headroom.
constexpr int MAXN = 1000007;
// Prime modulus applied to every answer.
constexpr int mod = 1000000007;

int fct[MAXN];  // fct[i] = i! % mod, filled by DealFact()
int vis[MAXN];  // vis[x] = 1 once the sieve in DealPrime() marks x composite
int pr[MAXN];   // pr[1..cnt] = primes found by DealPrime(), in increasing order
int num[MAXN];  // num[x] = number of distinct primes of squarefree x > 1 (DealRc), else 0

int q;    // number of queries read in main()
int cnt;  // number of primes stored in pr[]
// Extracts the next whitespace-separated integer from standard input into `temp`.
inline void read(int &temp) {
    std::cin >> temp;
}
// Precomputes factorials modulo `mod`: fct[v] = v! % mod for 0 <= v <= 100000.
inline void DealFact() {
    fct[0] = 1;  // 0! = 1 seeds the recurrence
    for (int v = 1; v <= 100000; ++v) {
        fct[v] = v * fct[v - 1] % mod;
    }
}
// Linear (Euler) sieve over [2, 100000].
// Every prime is appended to pr[1..cnt] in increasing order, and vis[x] is set
// to 1 for every composite x. The early break on `x % pr[idx] == 0` makes each
// composite get marked exactly once, by its smallest prime factor.
inline void DealPrime() {
    constexpr int kLimit = 100000;
    for (int x = 2; x <= kLimit; ++x) {
        if (vis[x] == 0) {
            pr[++cnt] = x;  // never marked by a smaller prime => x is prime
        }
        for (int idx = 1; idx <= cnt && x * pr[idx] <= kLimit; ++idx) {
            vis[x * pr[idx]] = 1;
            if (x % pr[idx] == 0) {
                break;  // pr[idx] divides x: larger primes are not the smallest factor
            }
        }
    }
}
// Recursive helper for DealRc().
// `prod` is a squarefree product of `used` distinct primes, all taken from
// pr[1..from-1]. The helper extends it by each prime pr[j] with j >= from, as
// long as the product stays <= 100000. It records the new prime count in num[]
// and then recurses to append further, larger primes. Relies on pr[] being in
// increasing order (as DealPrime() produces), so the first product that exceeds
// the bound ends the scan.
inline void DealRcDfs(int from, int prod, int used) {
    for (int j = from; j <= cnt; ++j) {
        // For positive integers, prod * pr[j] > 100000  <=>  pr[j] > 100000 / prod.
        // The division form cannot overflow, whatever the width of `int`.
        if (pr[j] > 100000 / prod) break;
        const int next = prod * pr[j];
        num[next] = used + 1;
        DealRcDfs(j + 1, next, used + 1);
    }
}

// Fills num[x] with the number of distinct prime factors of every squarefree
// x in [2, 100000]. Entries for all other x are left untouched (zero, since
// num[] is a global). calc() treats num[d] != 0 as "d > 1 is squarefree" and
// uses its parity as the sign of the Möbius function mu(d).
// Requires DealPrime() to have filled pr[1..cnt].
//
// This replaces six hand-nested loops. Those capped the count at 6 primes,
// which is only enough while the bound stays below 2*3*5*7*11*13*17 = 510510.
// The recursion has no fixed depth.
inline void DealRc() {
    DealRcDfs(1, 1, 0);
}
// Binary (square-and-multiply) exponentiation: returns base^k % mod.
// Assumes k >= 0; a negative k never reaches 0 under the arithmetic right shift.
// Assumes 0 <= base < mod so that base * base fits in the 64-bit `int`
// (the file's #define) — NOTE(review): callers pass already-reduced values.
inline int ksm(int base, int k) {
    int acc = 1;
    for (; k != 0; k >>= 1) {
        if (k & 1) {
            acc = acc * base % mod;  // this bit of k contributes base^(2^bit)
        }
        base = base * base % mod;    // advance to base^(2^(bit+1))
    }
    return acc;
}
// Binomial coefficient C(n, m) % mod.
// Computed as n! / ((n-m)! m!), with the denominator inverted via Fermat's
// little theorem (x^(mod-2)), which is valid because mod = 1e9+7 is prime.
// Returns 0 when m < 0 or m > n (including any n < 0). Previously such inputs
// indexed fct[] out of range: e.g. calc(n, k) with k > n or k == 0 calls
// C(n - 1, k - 1) with a negative n - m or m.
// Requires DealFact() to have filled fct[0..n].
inline int C(int n, int m) {
    if (m < 0 || m > n) return 0;
    return fct[n] * ksm(fct[n - m] * fct[m] % mod, mod - 2) % mod;
}
// Counts ordered k-tuples of positive integers summing to n whose gcd is 1, mod `mod`.
// C(m - 1, k - 1) counts compositions of m into k positive parts (stars and
// bars). Möbius inversion over the common divisor d then gives
//     answer = sum over squarefree d | n of mu(d) * C(n/d - 1, k - 1),
// with mu(1) = 1 and mu(d) = (-1)^num[d] for squarefree d > 1.
// Divisors are enumerated in pairs (i, n / i) with i * i <= n.
// Requires DealFact/DealPrime/DealRc to have run; fct[] only covers n <= 100000.
inline int calc(int n, int k) {
    // d = 1 term: every composition of n into k positive parts.
    int res = C(n - 1, k - 1);

    // Adds mu(d) * C(quotient - 1, k - 1) for divisor d with quotient = n / d.
    // Skips d when it is not a squarefree number > 1, or when quotient < k.
    auto apply = [&](int d, int quotient) {
        if (!num[d] || quotient < k) return;
        const int ways = C(quotient - 1, k - 1);
        if (num[d] & 1) {
            res = ((res - ways) % mod + mod) % mod;  // odd prime count: mu(d) = -1
        } else {
            res = (res + ways) % mod;                // even prime count: mu(d) = +1
        }
    };

    for (int i = 1; i * i <= n; ++i) {
        if (n % i != 0) continue;
        apply(i, n / i);
        if (i * i != n) apply(n / i, i);  // partner divisor, unless it equals i
    }
    return res;
}
// Reads q queries "n k" and prints calc(n, k) for each, one per line.
signed main() {
    // Decouple C++ streams from C stdio and untie cin for faster bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    read(q);

    // Build the factorial, prime and squarefree-prime-count tables once,
    // in dependency order (DealRc reads the primes found by DealPrime).
    DealFact();
    DealPrime();
    DealRc();

    for (int left = q; left > 0; --left) {
        int n, k;
        read(n);
        read(k);
        cout << calc(n, k) << '\n';
    }
    return 0;
}