P12011 【MX-X10-T7】[LSOT-4] 春开,意遥遥。

· · 题解

这题作为 MX-X 最后一题算是很简单了,但是我怎么做了这么久?

很难不发现 (x, y)(a, b) 做乘法得到的二元组 (ay + bx, ax + by),它两项的和等于 (x + y) \times (a + b),后一项减前一项的差等于 (y - x) \times (b - a)。这启发我们用 (x + y, y - x) 来表示二元组,这样二元组的乘法就是对应的项相乘,同时两个二元组 (x, y), (a, b) 相同的条件仍然是 ay \equiv bx \pmod p。并且 (x + y, y - x)p 为奇素数时能和 (x, y) 构成双射。当 p = 2 时需要特判,若区间出现 (1, 0) 且不出现 (0, 0), (1, 1) 时答案为 2,否则为 1

因为 b_i 需要是正整数,所以 \prod a_i 如果有至少一项是 0,那么答案为 1。否则 \prod a_i 两项都 \ge 1。现在我们可以直接用 \frac{x}{y} 来替换二元组 (x, y),这样序列变成了一个正整数序列。

看到乘法考虑取离散对数,假设找到了 p 的一个原根 g,设 b_i 为最小的正整数使得 g^{b_i} \equiv a_i \pmod p。那么乘积变成了和,一个序列的答案即为 \frac{p - 1}{\gcd b_i}

但是对每个 a_i 求离散对数时间复杂度太大了。考虑求阶,设 c_i 为最小的正整数使得 a_i^{c_i} \equiv 1 \pmod p,那么 b_i = \frac{p - 1}{c_i}。一个序列的答案变为 \operatorname{lcm} c_i

现在问题变成求一个序列所有子区间的 \operatorname{lcm} 的和。考虑枚举子区间右端点,因为序列中所有数都是 p - 1 的因数,所以左端点变化时 \operatorname{lcm} 只会变化 O(\log V) 次。维护每种 \operatorname{lcm} 对应的左端点范围即可。

时间复杂度 O(\sqrt p + n \log^2 p)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef __int128 lll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 100100;
const ll mod = 1000000007;

ll n, P, a[maxn], b[maxn], c[maxn], f[maxn], m;

inline ll qpow(ll b, ll p, ll mod) {
    ll res = 1;
    while (p) {
        if (p & 1) {
            res = (lll)res * b % mod;
        }
        b = (lll)b * b % mod;
        p >>= 1;
    }
    return res;
}

int tot;
pii p[99], q[99];

inline ll calc(ll x) {
    if (x == 1) {
        return 1;
    }
    ll y = P - 1;
    for (int i = 1; i <= tot; ++i) {
        for (int j = 1; j <= p[i].scd; ++j) {
            if (qpow(x, y / p[i].fst, P) == 1) {
                y /= p[i].fst;
            } else {
                break;
            }
        }
    }
    return y;
}

inline ll work() {
    int tot = 1;
    ll ans = f[1] % mod;
    p[1] = mkp(1, f[1]);
    for (int i = 2; i <= m; ++i) {
        p[++tot] = mkp(i, f[i]);
        for (int j = 1; j <= tot; ++j) {
            p[j].scd = f[i] / __gcd(f[i], p[j].scd) * p[j].scd;
        }
        int nt = 1;
        q[1] = p[1];
        for (int j = 2; j <= tot; ++j) {
            if (p[j].scd != p[j - 1].scd) {
                q[++nt] = p[j];
            }
        }
        tot = nt;
        q[tot + 1].fst = i + 1;
        for (int j = 1; j <= tot; ++j) {
            p[j] = q[j];
            ans = (ans + q[j].scd % mod * (q[j + 1].fst - q[j].fst)) % mod;
        }
    }
    return ans;
}

void solve() {
    scanf("%lld%lld", &n, &P);
    if (P == 2) {
        int p = 0, q = 0;
        ll ans = n * (n + 1) / 2;
        for (int i = 1, x, y; i <= n; ++i) {
            scanf("%d%d", &x, &y);
            if (x == 1 && y == 0) {
                q = i;
            } else if (x == y) {
                p = i;
            }
            ans = (ans + max(q - p, 0)) % mod;
        }
        printf("%lld\n", ans % mod);
        return;
    }
    ll x = P - 1;
    for (ll i = 2; i * i <= x; ++i) {
        if (x % i == 0) {
            ll cnt = 0;
            while (x % i == 0) {
                x /= i;
                ++cnt;
            }
            p[++tot] = mkp(i, cnt);
        }
    }
    if (x > 1) {
        p[++tot] = mkp(x, 1);
    }
    for (int i = 1; i <= n; ++i) {
        ll x, y;
        scanf("%lld%lld", &x, &y);
        a[i] = (x + y) % P;
        b[i] = (y - x + P) % P;
        if (a[i] || b[i]) {
            c[i] = calc((lll)a[i] * qpow(b[i], P - 2, P) % P);
        }
    }
    ll ans = 0, sl = n * (n + 1) / 2;
    for (int i = 1, j = 1; i <= n; i = (++j)) {
        if (!a[i] || !b[i]) {
            continue;
        }
        while (j < n && a[j + 1] && b[j + 1]) {
            ++j;
        }
        m = 0;
        for (int k = i; k <= j; ++k) {
            f[++m] = c[k];
        }
        ll len = j - i + 1;
        sl -= len * (len + 1) / 2;
        ans = (ans + work()) % mod;
    }
    ans = (ans + sl) % mod;
    printf("%lld\n", ans);
}

int main() {
    int T = 1;
    // scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}