P8193 [USACO22FEB] Sleeping in Class P

· · 题解

P8193 [USACO22FEB] Sleeping in Class P

提供一个不需要 Pollard - rho 的做法。

a 的前缀和数组为 s,容易发现若 q \nmid s_n 则无解,否则答案为 \left(\dfrac {s_n} q - 1\right) + (n - 1) - 2\sum\limits_{i = 1} ^ {n - 1} [q\mid s_i]。这是因为我们考虑将所有数全部合并再分裂,能够整除 q 的前缀和对应的切割点处不用分裂再合并,可以减少两步。

考虑这样一个朴素的暴力:对每个 s_n 的因数均求一遍答案。时间复杂度 \mathcal{O}(nd(s_n))。由于 s_n 可达 10 ^ {18},且 10 ^ {18} 内最大的约数个数级别为 10 ^ 5,所以无法接受。

注意到实际上只需求 2\sum\limits_{i = 1} ^ {n - 1} [q \mid s_i]。考虑每个 s_is_n 的每个因数 d 的贡献。显然,s_i 对所有 \gcd(s_i, s_n) 的因数有贡献。设 D = \gcd(s_i, s_n) = \prod p_i ^ {c_i},则任意 \prod\limits_{0\leq b_i \leq c_i} p_i ^ {b_i} 均需要加上 1。这是高维前缀和的形式,因此这一部分容易做到 n\log (s_n) + d(s_n)\omega(s_n)\log d(s_n),后面带一个 \log 是因为要把所有因子用 map 映射到一个较小的范围内,否则无法存储结果。哈希表如 gp_hash_table 应该可以去掉这个 \log

问题只剩下分解质因数了。不想写 Pollard - rho 怎么办?考虑只分解到 10 ^ 6,即用 10 ^ 6 以内的质因子试除 s_n。如果最终结果 \leq 10 ^ {12},说明剩下的 s_n 是一个质数(否则就被试除掉了),否则说明 s_n 有两个大于 10 ^ 6 的质因子(不一定不同)。对于前者,我们已经成功分解质因数。对于后者,注意到此时 s_n 的因子个数不超过 4 \times \max\limits_{i \leq 10 ^ 6}d(i),于是直接暴力即可。

时间复杂度 n\max\limits_{i \leq 10 ^ 6} d(i) + n\log (s_n) + d(s_n)\omega(s_n)\log d(s_n)。代码并不是很可读。

#include <bits/stdc++.h>
using namespace std;

#define ll long long
const int N = 2e5 + 5;
ll n, q, a[N], pr[N];
int pw[N], cnt, ppw[N];
map <ll, ll> mp;
int f[N];
int calc(int *pw) {
    int res = 0;
    for(int i = 1; i <= cnt; i++) res += pw[i] * ppw[i];
    return res;
}
ll rev(int x) {
    ll res = 1;
    for(int i = 1; i <= cnt; i++) {
        int d = x / ppw[i] % (pw[i] + 1);
        for(int j = 1; j <= d; j++) res *= pr[i];
    } return res;
}

int fix, cpw[N];
void dfs(int id) {
    if(id > cnt) {
        int cur = calc(cpw);
        f[cur - ppw[fix]] += f[cur];
        return;
    } for(int i = pw[id]; i >= (id == fix); i--) cpw[id] = i, dfs(id + 1);
}
void check() {
    ll tmp = a[n];
    for(int i = 2; i <= 1e6; i++)
        if(tmp % i == 0) {
            pr[++cnt] = i;
            while(tmp % i == 0) pw[cnt]++, tmp /= i;
        }
    if(tmp > 1e12) return;
    if(tmp > 1) pr[++cnt] = tmp, pw[cnt] = 1;
    ppw[cnt] = 1;
    for(int i = cnt - 1; ~i; i--) ppw[i] = ppw[i + 1] * (pw[i + 1] + 1);
    for(int i = 1; i < n; i++) {
        ll tmp = a[i];
        for(int j = 1; j <= cnt; j++) {
            int cur = 0;
            while(tmp % pr[j] == 0) cur++, tmp /= pr[j];
            cpw[j] = min(pw[j], cur);
        } f[calc(cpw)]++;
    }
    for(int i = 1; i <= cnt; i++) fix = i, dfs(1);
    for(int i = 0; i < ppw[0]; i++) mp[rev(i)] = n - 1 + a[n] / rev(i) - 1 - f[i] * 2;
}
int main() {
    cin >> n;
    for(int i = 1; i <= n; i++) cin >> a[i], a[i] += a[i - 1];
    check(), cin >> q;
    for(int i = 1; i <= q; i++) {
        ll x; cin >> x;
        if(a[n] % x) {puts("-1"); continue;}
        if(!mp[x]) {
            ll ans = 0;
            for(int j = 1; j < n; j++) ans += a[j] % x == 0;
            mp[x] = n - 1 + a[n] / x - 1 - ans * 2;
        }
        cout << mp[x] << "\n";
    }
    return 0;
}

/*
stupid mistakes:
    进制转换写错了
    rev : res *= pr[i] -> *= pr[j]
*/