题解 P5330 【[SNOI2019]数论】

· · 题解

题意

T 以内模 P 值属于集合 A 且 模 B 值属于集合 B 的数的数量。

T\le 10^{18},P,Q,|A|,|B|\le 10^6

首先对一个模型作介绍。

对于 [0,Q) 中的数都建立一个点,若点 i 连一条边至 (i+P)\mod Q,那么整个图会分为 \gcd(P,Q) 个独立的环,每个环的点数是 \dfrac{\operatorname{lcm(P,Q)}}{P}

说明:

每一个点走若干个 P 步回到自己,同时经过了若干个 Q 总长,因此经过的总长是 \operatorname{lcm(P,Q)},点数 \dfrac{\operatorname{lcm(P,Q)}}{P}。而总点数 Q,所以环数 Q/(\dfrac{\operatorname{lcm(P,Q)}}{P})=\gcd(P,Q)

假设 P \le Q。因为 [0,P) 中的数要不变地唯一对应到 [0,Q) 中的一个数。

在这个图上考虑问题,我们可以枚举 a_i,在图上直接找到它,考虑它沿边走 k 步得到的数是 (a_i + kP)\mod Q,我们需要计算包括自己,在 T-1 以内,一共经过多少 b_i

要解决的问题直观了起来。我们可以在图上标记所有 b_i 权值为 1 其余为 0,预处理环上总和与 1 圈以内的前缀和。

对于 a_i,包括自己一共要统计 \lfloor\dfrac{T-1-a_i}{P}\rfloor+1 个点。首先直接计算整圈,然后对于留下的不满一圈的路径,分类讨论一下计算即可。

要注意的是:前缀和中选定的起点值是 0 还是总和;t-1-a_i<0 要直接舍弃。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>

typedef long long ll;
const int N = 1000001;

inline ll read() {
    ll x = 0, f = 1; char ch = getchar();
    while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
    do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9');
    return x * f;
}

int gcd(int x,int y) { return !y ? x : gcd(y,x % y); }

int P,Q,n,m,a[N],b[N]; ll t;
int vis[N],at[N]; ll dist[N],ring[N],sum[N],len;

ll Prefix(ll a,ll x) {
    if(!x) return 0;
    ll b = (a + (x - 1) * P) % Q;
    if(x + dist[a] > len) return sum[a] - (ring[a] - ring[b] - at[a]);
    else return ring[b] - ring[a] + at[a];
}

ll Calc(ll T) {
    ll ans = 0;
    for(int i = 1;i <= n;i++) {
        if(T < a[i]) continue;
        ll st = (T - a[i]) / P + 1;
        ans += st / len * sum[a[i]] + Prefix(a[i],st % len);
    }
    return ans;
}

int main() {
    P = read(), Q = read(), n = read(), m = read(), t = read();
    if(P > Q) {
        std::swap(P,Q), std::swap(n,m);
        for(int i = 1;i <= m;i++) b[i] = read();
        for(int i = 1;i <= n;i++) a[i] = read();
    } else {
        for(int i = 1;i <= n;i++) a[i] = read();
        for(int i = 1;i <= m;i++) b[i] = read();
    }
    for(int i = 1;i <= m;i++) at[b[i]] = true;
    for(int i = 0;i < Q;i++) if(!vis[i]) {
        int p = i, k = (i + P) % Q;
        while(!vis[p]) {
            vis[p] = true, dist[k] = dist[p] + 1;
            ring[k] = ring[p] + at[k];
            p = k, k = (k + P) % Q;
        }
        k = (i + P) % Q, sum[i] = ring[i], ring[i] = dist[i] = 0;
        while(k != i) sum[k] = sum[i], k = (k + P) % Q;
    }
    len = Q / gcd(P,Q);
    std::printf("%lld\n",Calc(t - 1));
    return 0;
}