Solution to CF1749D Counting Arrays

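Observe that every array $a$ admits the removal sequence $[1, 1, \dots, 1]$, since $\gcd(a_1, 1) = 1$ always holds. Therefore, whenever there exist indices $i, j$ with $2 \le j \le i \le |a|$ (where $|a|$ is the length of $a$) and $\gcd(a_i, j) = 1$, we obtain a second removal sequence: remove the first element $i - j$ times so that $a_i$ shifts to position $j$, remove the $j$-th element, and then keep removing the first element until the array is empty. Conversely, if no such pair exists, an element originally at index $i$ can only ever occupy positions $1, \dots, i$ and can never be removed from a position $j \ge 2$, so every step must remove the first element and $[1, 1, \dots, 1]$ is the only removal sequence.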

Thus, the problem reduces to counting the arrays of every length from $1$ to $n$, with each $a_i$ an integer in $[1, m]$, that contain such a pair $(i, j)$.
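To sanity-check this characterization on small cases, one can enumerate removal sequences directly. The sketch below is a brute-force checker of our own, not part of the original solution; the names `gcdLL`, `countRemovalSequences`, `hasShiftablePair`, and `criterionMatchesBruteForce` are hypothetical. It assumes 1-indexed positions as in the text and runs in exponential time, so it is only meant for tiny lengths and tiny $m$.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Euclid's algorithm (kept local so the sketch does not depend on C++17 std::gcd).
long long gcdLL(long long x, long long y) {
    while (y != 0) {
        long long r = x % y;
        x = y;
        y = r;
    }
    return x;
}

// Counts the removal sequences of `a`: position j (1-indexed) may be removed
// only if gcd(a_j, j) == 1. Exponential time, so small arrays only.
long long countRemovalSequences(const std::vector<long long>& a) {
    if (a.empty()) return 1;  // one completed sequence
    long long total = 0;
    for (std::size_t j = 0; j < a.size(); j++) {
        if (gcdLL(a[j], static_cast<long long>(j + 1)) == 1) {
            std::vector<long long> rest(a);
            rest.erase(rest.begin() + static_cast<std::ptrdiff_t>(j));
            total += countRemovalSequences(rest);
        }
    }
    return total;
}

// The criterion from the text: some 2 <= j <= i <= |a| with gcd(a_i, j) == 1.
bool hasShiftablePair(const std::vector<long long>& a) {
    for (std::size_t i = 2; i <= a.size(); i++)
        for (std::size_t j = 2; j <= i; j++)
            if (gcdLL(a[i - 1], static_cast<long long>(j)) == 1) return true;
    return false;
}

// Enumerates all arrays of length `len` over [1, m] and returns true if
// "more than one removal sequence" coincides with hasShiftablePair everywhere.
bool criterionMatchesBruteForce(int len, long long m) {
    assert(len >= 1 && m >= 1);
    std::vector<long long> a(static_cast<std::size_t>(len), 1);
    while (true) {
        bool ambiguous = countRemovalSequences(a) > 1;
        if (ambiguous != hasShiftablePair(a)) return false;
        // Advance `a` like an odometer over [1, m]^len.
        int k = 0;
        while (k < len && a[k] == m) a[k++] = 1;
        if (k == len) return true;  // every array has been checked
        a[k]++;
    }
}
```

For example, `countRemovalSequences({1, 1})` is `2` (remove position 1 twice, or position 2 then position 1), while `countRemovalSequences({2, 2})` is `1`, and `criterionMatchesBruteForce(4, 6)` should return `true`.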

Consider each length $|a|$ separately.

Fix the length $|a|$. The number of arrays satisfying the condition equals $m^{|a|}$ minus the number of arrays $a$ such that

$$\forall\, 2 \le i \le |a|,\ \forall\, 2 \le j \le i:\quad \gcd(a_i, j) > 1.$$

Let $T$ denote the number of arrays satisfying the condition above. The condition restricts each $a_i$ on its own, independently of the other elements, so we can count the choices for each $a_i$ separately.

Let $c_i$ be the number of values $a_i$ can take. Clearly $c_1 = m$, because $a_1$ is unconstrained.

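For $i \ge 2$, the requirement $\gcd(a_i, j) > 1$ for every $2 \le j \le i$ holds exactly when $a_i$ is a multiple of every prime $p \le i$: choosing $j = p$ forces $p \mid a_i$, and conversely every $j \in [2, i]$ has a prime factor $p \le i$, which then divides $a_i$. Let $d = \prod\limits_{2 \le p \le i} p$, where $p$ ranges over primes. Then $c_i = \left\lfloor\dfrac{m}{d}\right\rfloor$.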
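The values $c_i$ can be computed with a running product of primes. The following is a minimal sketch assuming $n \ge 1$ and $m \ge 1$; the name `choicesPerPosition` is hypothetical, and it uses a plain Sieve of Eratosthenes rather than the linear sieve in the implementation. Because $d$ grows very quickly, the product is capped at $m + 1$ (the same trick the implementation applies to `lc`), after which every $c_i$ is $0$.

```cpp
#include <cassert>
#include <vector>

// Returns c[0..n] (index 0 unused) where c[1] = m and, for i >= 2,
// c[i] = floor(m / d) with d = product of all primes <= i.
// The running product d is capped at m + 1 so it never overflows;
// once it exceeds m, every later c[i] is 0.
std::vector<long long> choicesPerPosition(int n, long long m) {
    assert(n >= 1 && m >= 1);
    std::vector<bool> composite(n + 1, false);
    std::vector<long long> c(n + 1, 0);
    c[1] = m;
    long long d = 1;
    for (int i = 2; i <= n; i++) {
        if (!composite[i]) {  // i is prime: mark its multiples, extend d
            for (long long k = 1LL * i * i; k <= n; k += i) composite[k] = true;
            d = (d <= m / i) ? d * i : m + 1;
        }
        c[i] = m / d;
    }
    return c;
}
```

For instance, with $m = 10$ and $n = 5$ this yields $c_1, \dots, c_5 = 10, 5, 1, 1, 0$.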

Since the condition must hold for every $2 \le i \le |a|$, $T = \prod\limits_{i=1}^{|a|} c_i$.

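The answer is the sum of $m^{|a|} - T$ over all $1 \le |a| \le n$, taken modulo $998244353$. For $|a| = 1$ the term is $m - c_1 = 0$, and moving from length $|a| - 1$ to $|a|$ multiplies $T$ by $c_{|a|}$, so $T$ can be maintained as a running product while iterating over the lengths.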
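As a small hand-worked instance of the formula (with $n = 4$ and $m = 6$, chosen only for illustration):

$$
\begin{aligned}
c_1 &= 6,\quad c_2 = \left\lfloor \tfrac{6}{2} \right\rfloor = 3,\quad c_3 = c_4 = \left\lfloor \tfrac{6}{2 \cdot 3} \right\rfloor = 1,\\
\text{answer} &= (6 - 6) + (6^2 - 6 \cdot 3) + (6^3 - 6 \cdot 3 \cdot 1) + (6^4 - 6 \cdot 3 \cdot 1 \cdot 1)\\
&= 0 + 18 + 198 + 1278 = 1494.
\end{aligned}
$$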

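Time complexity: $\Theta(n \log n)$. The linear sieve and the computation of all $c_i$ take $\Theta(n)$; the extra $\log n$ factor comes from computing each $m^{|a|}$ with exponentiation by squaring.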
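The $\log n$ factor is not essential: multiplying a running power by $m$ once per length produces every $m^{|a|}$ in $O(1)$ each. The sketch below is our own variation, not the original implementation; the name `countAmbiguousArrays` is hypothetical, and for brevity it uses the Sieve of Eratosthenes, so it runs in $O(n \log \log n)$ overall. Swapping in the linear sieve from the implementation would make the whole computation $\Theta(n)$.

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 998244353;

// Sum over lengths 1..n of (m^len - T_len) mod MOD, where T_len = c_1 * ... * c_len.
// m^len is kept as a running product instead of calling fast exponentiation.
long long countAmbiguousArrays(int n, long long m) {
    assert(n >= 1 && m >= 1);
    std::vector<bool> composite(n + 1, false);
    long long d = 1;                // product of primes <= len, capped at m + 1
    long long mMod = m % MOD;
    long long power = mMod;         // m^len mod MOD, starting at len = 1
    long long notAmbiguous = mMod;  // T_len mod MOD, starting from c_1 = m
    long long answer = 0;           // length 1 contributes m - m = 0
    for (int len = 2; len <= n; len++) {
        if (!composite[len]) {      // len is prime: mark multiples, extend d
            for (long long k = 1LL * len * len; k <= n; k += len) composite[k] = true;
            d = (d <= m / len) ? d * len : m + 1;
        }
        power = power * mMod % MOD;
        notAmbiguous = notAmbiguous * ((m / d) % MOD) % MOD;
        answer = (answer + power - notAmbiguous + MOD) % MOD;
    }
    return answer;
}
```

For example, `countAmbiguousArrays(2, 3)` gives `6`, `countAmbiguousArrays(4, 2)` gives `26`, and `countAmbiguousArrays(4, 6)` gives `1494`, matching the hand computation of the $n = 4$, $m = 6$ case.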

:::success[Implementation]

```cpp
#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using pii = pair<int, int>;

#ifdef ONLINE_JUDGE
#define debug(...) 0
#else
#define debug(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)
#endif

constexpr int N = 3e5 + 5, mod = 998244353;

// Exponentiation by squaring: returns a^b mod `mod`.
int qpow(int a, int b) {
    a %= mod;
    int res = 1;
    while (b) {
        if (b & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return res;
}

int n; ll m;
ll c[N];  // c[i] = number of values a_i can take with gcd(a_i, j) > 1 for all 2 <= j <= i

bool notprime[N];
vector<int> primes;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;

    // Linear sieve: collect all primes up to n.
    for (int i = 2; i <= n; i++) {
        if (!notprime[i]) primes.emplace_back(i);
        for (int j = 0; primes[j] * i <= n; j++) {
            notprime[primes[j] * i] = 1;
            if (i % primes[j] == 0) break;
        }
    }

    // lc = product of all primes <= i, capped at m + 1 to avoid overflow
    // (once it exceeds m, c[i] = m / lc is 0 anyway).
    ll lc = 1;

    for (int i = 2; i <= n; i++) {
        if (!notprime[i]) {
            if (lc <= m / i) lc *= i;
            else lc = m + 1;
        }
        c[i] = m / lc;
    }

    // prod holds T for the current length, starting from c_1 = m.
    // Length 1 contributes m - m = 0, so the loop starts at length 2.
    int ans = 0, prod = m % mod;
    for (int i = 2; i <= n; i++) {
        prod = c[i] % mod * prod % mod;
        ans = ((ans + qpow(m % mod, i)) % mod - prod + mod) % mod;
    }

    cout << ans << "\n";
    return 0;
}
```

:::