P7328 「MCOI-07」Dream and Machine Learning

· · 题解

P7328 「MCOI-07」Dream and Machine Learning

给出一个 不依赖于 b\in \{2, 3\},且不太依赖于 n 的大小的做法。

注意到 e_i 随机给出,我们尽可能利用这个性质。同时,注意到若 e_a + e_b = e_c + e_d,则 r_ar_b\equiv r_cr_d\pmod m,这说明 r_ar_b - r_cr_dm 的倍数。

我们求出若干组这样的 a, b, c, d,并对所有 |r_ar_b - r_cr_d|\gcd 即可得到 m。剩下来是 trivial 的。

由于 r_ar_b - r_cr_d\mathcal{O}(m ^ 2) 级别,所以期望进行很少次 \gcd 即可得到正确的 m。实际测试中只需要 c = 10 次即可。

接下来考虑如何求出 a, b, c, d。很经典的 meet-in-the-middle 套路。在时间限制范围内将足够多的 (a, b) 塞入哈希表(设为 B 个),key 值为 e_a + e_b。直接枚举 c, d 并检查 e_c + e_d 是否在哈希表中出现,若出现且对应的 value (a, b)\neq (c, d),则可以进行一次 gcd。枚举四元组的时间复杂度期望为 c\times \dfrac {\max e_i}{B}

笔者实现后发现时间瓶颈在于高精度快速幂,如果使用 FFT 优化和压位优化应该能跑得更快。求解 m 并不是瓶颈。

接下来说说这个做法相较于 std 可以扩展的地方。

首先它不依赖于 b,所以这个 b 可以出到 \mathcal{O}(m) 级别。

其次,它不太依赖于 n 的级别。当 n 变小的时候,我们仍然可以将足够多的 T 元组扔进哈希表,并在与 n 很大时相同的复杂度得到足够多的 2T 元组求解 gcd。

唯一的问题在于这会使得 \prod_{i\in X} r_i - \prod_{i\in Y} r_i 变得很大,级别为 \mathcal{O}(m ^ T),不过这个问题其实并不严重,因为这只给复杂度乘了常数,或者说给求解 m 的复杂度乘上了 T ^ 3\approx \log_n ^ 3 \max e_i(求解 \gcd 的次数乘上 T,乘法复杂度乘上 T ^ 2)。

综上,本题可以加强至 b 为任意 m 以内的数,n = 100,严格优于使用生日悖论的方法。而且不算难写(除了高精度板子,和 binary GCD)。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
const int L = 600;
bool Mbe;
struct BigInt {
    int len;
    short a[L];
    BigInt() {memset(a, 0, L << 1), len = 0;} // ADD len = 0
    void print() {if(len) for(int i = len - 1; ~i; i--) putchar(a[i] + '0'); putchar('\n');}
    void read() {
        char s = getchar();
        while(!isdigit(s)) s = getchar();
        while(isdigit(s)) a[len++] = s - '0', s = getchar();
        reverse(a, a + len);
    }
    void flush() {
        len += 2;
        for(int i = 0; i < len; i++) a[i + 1] += a[i] / 10, a[i] %= 10;
        while(len && !a[len - 1]) len--;
    }
    void shrink() {
        len += 2;
        for(int i = 0; i < len; i++) if(a[i] < 0) a[i + 1]--, a[i] += 10;
        while(len && !a[len - 1]) len--;
    }
    bool operator < (BigInt rhs) {
        if(len != rhs.len) return len < rhs.len;
        for(int i = len - 1; ~i; i--) if(a[i] != rhs.a[i]) return a[i] < rhs.a[i];
        return 0;
    }
    BigInt operator + (BigInt rhs) {
        BigInt res;
        res.len = max(len, rhs.len);
        for(int i = 0; i < res.len; i++) res.a[i] = a[i] + rhs.a[i]; // len -> res.len
        res.flush();
        return res;
    }
    BigInt operator - (BigInt rhs) {
        BigInt res;
        res.len = max(len, rhs.len);
        for(int i = 0; i < res.len; i++) res.a[i] = a[i] - rhs.a[i];
        res.shrink();
        return res;
    }
    BigInt operator * (BigInt rhs) {
        BigInt res;
        res.len = len + rhs.len;
        for(int i = 0; i < len; i++)
            for(int j = 0; j < rhs.len; j++)
                res.a[i + j] += a[i] * rhs.a[j];
        res.flush();
        return res;
    }
    BigInt operator * (int v) {
        BigInt res;
        res.len = len; // ADD THIS LINE
        for(int i = 0; i < len; i++) res.a[i] = a[i] * v;
        res.flush();
        return res;
    }
    BigInt operator / (int v) {
        BigInt res;
        res.len = len;
        int rem = 0;
        for(int i = len - 1; ~i; i--) rem = rem * 10 + a[i], res.a[i] = rem / v, rem %= v;
        while(res.len && !res.a[res.len - 1]) res.len--;
        return res;
    }
    BigInt operator % (BigInt rhs) {
        BigInt rem;
        for(int i = len - 1; ~i; i--) {
            rem = rem * 10, rem.a[0] += a[i], rem.flush();
            while(!(rem < rhs)) rem = rem - rhs;
        }
        return rem;
    }
} f[N], m, pw[32];
BigInt gcd(BigInt x, BigInt y) {
    int pw = 0;
    while(x.len && x.a[0] % 2 == 0 && y.len && y.a[0] % 2 == 0) pw++, x = x / 2, y = y / 2;
    while(x.len && y.len) {
        while(x.a[0] % 2 == 0) x = x / 2;
        while(y.a[0] % 2 == 0) y = y / 2;
        if(x < y) swap(x, y);
        x = x - y;
    }
    if(x < y) swap(x, y);
    while(pw--) x = x * 2;
    return x;
}
int cnt = 20, b, n, q, e[N];
map <int, pair <int, int>> mp;
mt19937 rnd(time(0));
int rd(int l, int r) {return rnd() % (r - l + 1) + l;}
BigInt ksm(int b) {
    BigInt s;
    s.len = s.a[0] = 1;
    int cnt = 0;
    while(b) {
        if(b & 1) s = s * pw[cnt] % m;
        b >>= 1, cnt++;
    }
    return s;
}
void solve() {
    pw[0].len = 1, pw[0].a[0] = b;
    for(int i = 1; i < 30; i++) pw[i] = pw[i - 1] * pw[i - 1] % m;
    while(q--) {
        int a;
        scanf("%d", &a);
        ksm(a).print();
    }
    exit(0);
}
bool Med;
int main() {
    fprintf(stderr, "%.4lf\n", (&Med - &Mbe) / 1048576.0);
    // freopen("sample.in", "r", stdin);
    cin >> b >> n >> q;
    for(int i = 1; i <= n; i++) scanf("%d", &e[i]), f[i].read();
    for(int i = 1; i <= 1e5; i++) {
        int a = rd(1, n), b = rd(1, n);
        mp[e[a] + e[b]] = make_pair(a, b);
    }
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++) {
            auto it = mp.find(e[i] + e[j]);
            if(it == mp.end()) continue;
            int a = it -> second.first, b = it -> second.second;
            if(a == i || b == i) continue;
            BigInt x = f[a] * f[b], y = f[i] * f[j];
            if(x < y) swap(x, y);
            m = gcd(m, x - y);
            if(!--cnt) solve();
        }
    return 0;
}