题解 P5668 【【模板】N 次剩余】

· · 题解

Update on 2023.9.8:已通过 Hack 数据。

这里的 m 没有特殊性质,考虑将其分解为 m = \displaystyle\prod_{i = 1}^{\omega(m)} p_i^{q_i},并在依次求出每个 x^n \equiv k \pmod{p_i^{q_i}} 的所有解后暴力合并。

问题转化为求出形如 x^n \equiv k \pmod{p^q} 的方程的所有解。

  1. k \bmod p^q = 0

此时设 x = ap^b, k = p^ca \bmod p \neq 0 \operatorname{and} b, c \in Z),则 bn \geq cb \geq \lceil \frac{c}{n} \rceil,于是此时 x 可以取遍 p^q 以内所有包含 0 在内的 p^{\lceil \frac{c}{n} \rceil} 的倍数。直接暴力枚举即可。

  1. k \bmod p^q \neq 0 , k \bmod p = 0

此时设 x = ap^b, k = cp^da, c \bmod p \neq 0 \operatorname{and} b, d \in Z),则 bn = da^n \equiv c \pmod{p^{q - d}}。此时我们可以直接解出 b = \frac{d}{n},求余下那个方程交给接下来要讲的两种情况。

考虑解出所有 a 后的操作。对于每个符合条件的 at \in Z,令 a' = a + tp^{q - d},则所有 x = a'p^b < p^q 都是可行的解。

  1. k \bmod p \neq 0, p \neq 2 \operatorname{or} q \leq 2

此时求出 p^q 的一个原根 g,设 x = g^a, k = g^b,则 na \equiv b \pmod{\varphi(p^q)}。在求出原根后,b 可以用 BSGS 解出,a 可以用 exgcd 解线性同余方程得出。

  1. k \bmod p \neq 0, p = 2, q > 2

这里给出一个结论:

证明:首先不难归纳得出 5^{2^{q - 3}} \equiv 2^{q - 1} + 1 \pmod{2^q},于是 \operatorname{ord}_{2^q}(5) > 2^{q - 3};又因为 (2^{q - 1} + 1)^2 \equiv 1 \pmod{2^q},则 \operatorname{ord}_{2^q}(5) \mid 2^{q - 2};可得 \operatorname{ord}_{2^q}(5) = 2^{q - 2},又因为一定 \nexists b_1 \neq b_2, 5^{b_1} + 5^{b_2} \equiv 0 \pmod{2^q},则表示方式唯一。

于是我们 BSGS 解出 k 的表示方式,exgcd 求出 x 的所有可能表示方式即可。

时间复杂度看起来很高,但由于这道题保证每个质因数幂次对应的解数和所有询问的总解数 \leq 10^6,实际上可以通过。

需要注意的是,有可能某一个质因数幂次对应的方程无解但其他质因数幂次对应的方程解数都很大,此时需要提前判断是否存在一个质因数幂次对应的方程无解。

代码:

#include <iostream>
#include <algorithm>
#include <map>
#include <vector>
#include <cstdio>
#include <cmath>

using namespace std;

typedef long long ll;

typedef struct {
    int mod;
    vector<ll> v;

    inline void clear(){
        mod = 1;
        v.clear();
    }
} Equation;

map<int, int> mp1;
map<ll, int> mp2;
vector<Equation> v;

int gcd(int a, int b){
    return b == 0 ? a : gcd(b, a % b);
}

inline int lcm(int a, int b){
    return a * b / gcd(a, b);
}

void exgcd(ll a, ll b, ll &x, ll &y){
    if (b == 0){
        x = 1;
        y = 0;
        return;
    }
    ll t;
    exgcd(b, a % b, x, y);
    t = x;
    x = y;
    y = t - a / b * y;
}

Equation operator +(Equation &a, Equation &b){
    int sizea = a.v.size(), sizeb = b.v.size();
    ll d = gcd(a.mod, b.mod), x, y, t1 = a.mod / d;
    Equation ans;
    ans.mod = lcm(a.mod, b.mod);
    exgcd(a.mod, b.mod, x, y);
    for (register int i = 0; i < sizea; i++){
        for (register int j = 0; j < sizeb; j++){
            ll t2 = ((a.v[i] - b.v[j]) % a.mod + a.mod) % a.mod;
            if (t2 % d == 0) ans.v.push_back(((b.v[j] + b.mod * (y * (t2 / d) % t1) % ans.mod) % ans.mod + ans.mod) % ans.mod);
        }
    }
    sort(ans.v.begin(), ans.v.end());
    ans.v.erase(unique(ans.v.begin(), ans.v.end()), ans.v.end());
    return ans;
}

Equation operator +=(Equation &a, Equation &b){
    return a = a + b;
}

inline int quick_pow(int x, int p){
    int ans = 1;
    while (p){
        if (p & 1) ans *= x;
        x *= x;
        p >>= 1;
    }
    return ans;
}

inline int euler(int n){
    int ans = n;
    for (register int i = 2; (ll)i * i <= n; i++){
        if (n % i == 0){
            ans = ans / i * (i - 1);
            while (n % i == 0){
                n /= i;
            }
        }
    }
    if (n > 1) ans = ans / n * (n - 1);
    return ans;
}

inline void decompound(int n){
    mp1.clear();
    for (register int i = 2; i * i <= n; i++){
        while (n % i == 0){
            n /= i;
            mp1[i]++;
        }
    }
    if (n > 1) mp1[n] = 1;
}

inline ll quick_pow(ll x, ll p, ll mod){
    ll ans = 1;
    while (p){
        if (p & 1) ans = ans * x % mod;
        x = x * x % mod;
        p >>= 1;
    }
    return ans;
}

inline int get_least_primitive_root(int n){
    int phi_n = euler(n);
    decompound(phi_n);
    for (register int i = 0; i < n; i++){
        if (gcd(i, n) > 1) continue;
        bool flag = true;
        for (register map<int, int>::iterator j = mp1.begin(); j != mp1.end(); j++){
            if (quick_pow(i, phi_n / j->first, n) == 1){
                flag = false;
                break;
            }
        }
        if (flag) return i;
    }
    return -1;
}

inline ll inv(ll a, ll b){
    ll x, y;
    exgcd(a, b, x, y);
    return (x % b + b) % b;
}

inline int bsgs(int a, int b, int p){
    if (p == 1) return 0;
    a %= p;
    b %= p;
    if (b == 1) return 0;
    if (a == 0) return b == 0 ? 1 : -1;
    int n = ceil(sqrt(euler(p))), i = 0;
    ll t = quick_pow(a, n, p);
    mp2.clear();
    for (register ll j = b; i < n; i++, j = j * a % p){
        mp2[j] = i;
    }
    i = 1;
    for (register ll j = t; i <= n; i++, j = j * t % p){
        if (mp2.count(j)) return i * n - mp2[j];
    }
    return -1;
}

inline Equation solve1(int a, int b, int p){
    int g = get_least_primitive_root(p), c = bsgs(g, b, p), phi_p = euler(p), d = gcd(a, phi_p);
    Equation ans;
    ans.mod = p;
    if (c % d != 0) return ans;
    int t = phi_p / d;
    ll e = quick_pow(g, t, p), x, y, z;
    exgcd(a, phi_p, x, y);
    x = (x * (c / d) % t + t) % t;
    z = quick_pow(g, x, p);
    while (x < phi_p){
        ans.v.push_back(z);
        x += t;
        z = z * e % p;
    }
    sort(ans.v.begin(), ans.v.end());
    return ans;
}

inline vector<ll> solve2(int a, int b, int p){
    int d = gcd(a, p);
    if (b % d != 0) return vector<ll>();
    int t = p / d;
    ll x, y;
    vector<ll> ans;
    exgcd(a, p, x, y);
    x = (x * (b / d) % t + t) % t;
    while (x < p){
        ans.push_back(x);
        x += t;
    }
    return ans;
}

inline Equation solve3(int a, int b, int p, int p_pow_k){
    if (p > 2 || p_pow_k <= 4) return solve1(a, b, p_pow_k);
    int x = bsgs(5, b, p_pow_k), y, size1, size2;
    vector<ll> v1, v2;
    Equation ans;
    ans.mod = p_pow_k;
    if (x != -1){
        y = 0;
    } else {
        y = 1;
        x = bsgs(5, p_pow_k - b % p_pow_k, p_pow_k);
    }
    v1 = solve2(a, y, 2);
    v2 = solve2(a, x, p_pow_k / 4);
    size1 = v1.size();
    size2 = v2.size();
    for (register int i = 0; i < size1; i++){
        if (v1[i] == 0){
            v1[i] = 1;
        } else {
            v1[i] = p_pow_k - 1;
        }
    }
    for (register int i = 0; i < size2; i++){
        v2[i] = quick_pow(5, v2[i], p_pow_k);
    }
    for (register int i = 0; i < size1; i++){
        for (register int j = 0; j < size2; j++){
            ans.v.push_back(v1[i] * v2[j] % p_pow_k);
        }
    }
    sort(ans.v.begin(), ans.v.end());
    return ans;
}

void write(int n){
    if (n >= 10) write(n / 10);
    putchar(n % 10 + '0');
}

void write(ll n){
    if (n >= 10) write(n / 10);
    putchar(n % 10 + '0');
}

int main(){
    int t;
    cin >> t;
    for (register int i = 1; i <= t; i++){
        int n, m, k, ansa;
        bool flag = false;
        Equation ansb;
        cin >> n >> m >> k;
        v.clear();
        for (register int j = 2; j <= m; j++){
            if (j * j > m){
                if (m == 1) break;
                j = m;
            }
            if (m % j == 0){
                int power = 0;
                Equation equation;
                equation.clear();
                while (m % j == 0){
                    m /= j;
                    equation.mod *= j;
                    power++;
                }
                if (k % equation.mod == 0){
                    int x = quick_pow(j, (power - 1) / n + 1);
                    for (register int y = 0; y < equation.mod; y += x){
                        equation.v.push_back(y);
                    }
                } else if (k % j == 0){
                    int u = k, v = 1, w = 0;
                    while (u % j == 0){
                        u /= j;
                        v *= j;
                        w++;
                    }
                    if (w % n != 0){
                        flag = true;
                        break;
                    }
                    int size;
                    ll mul = quick_pow(j, w / n, equation.mod);
                    Equation t = solve3(n, k / v, j, equation.mod / v);
                    size = t.v.size();
                    for (register int x = 0; x < size; x++){
                        for (register int y = 0; ; y++){
                            ll val = mul * (t.v[x] + y * t.mod);
                            if (val >= equation.mod) break;
                            equation.v.push_back(val);
                        }
                    }
                } else {
                    equation = solve3(n, k, j, equation.mod);
                }
                if (equation.v.empty()){
                    flag = true;
                    break;
                }
                v.push_back(equation);
            }
        }
        if (flag){
            write(0);
            putchar('\n');
        } else {
            int size = v.size();
            ansb.clear();
            ansb.v.push_back(0);
            for (register int j = 0; j < size; j++){
                ansb += v[j];
            }
            ansa = ansb.v.size();
            write(ansa);
            putchar('\n');
            if (ansa != 0){
                for (register int j = 0; j < ansa; j++){
                    write(ansb.v[j]);
                    putchar(' ');
                }
                putchar('\n');
            }
        }
    }
    return 0;
}