「Wdsr-2」白泽教育-题解

· · 题解

Subtask 1

前置芝士:大步小步算法(BSGS)

由高德纳箭号表示法的定义可得:a \uparrow b = a^b。又由于 Subtask 1 的保证 p 为质数,所以我们使用 BSGS 求解原方程即可。时间复杂度为 O(T \sqrt{p} \log p)

Subtask 2

前置芝士:扩展欧拉定理(exEuler)

我们首先来化简 a \uparrow \uparrow b

a \uparrow \uparrow b = a \uparrow (a \uparrow \uparrow (b - 1)) = a^{a \uparrow \uparrow (b - 1)} = a^{a^{a \uparrow \uparrow (b - 2)}} = \cdots 预处理欧拉函数的值,然后从 $0$ 开始枚举 $x$ 的值并判断即可。设 $k_1$ 表示使欧拉函数变为 $1$ 的最少嵌套层数 $+ 1$,则 $x$ 的枚举上界为 $k_1$。 这里计算 $a \uparrow \uparrow b$ 时需要用到扩展欧拉定理。但是,由于我们在计算 $a^b \bmod p$ 之类的式子时,需要判断 $b$ 与 $\varphi(p)$ 之间的大小关系,所以我使用了一个名为 Node 的结构体来保存运算结果取模后的值和这个结果取模前是否大于 $\varphi(p)$。时间复杂度为 $O(T \sqrt{p} \log p)$。 ### Subtask $3

我们首先来化简 a \uparrow^3 b

a \uparrow^3 b = a \uparrow \uparrow (a \uparrow^3 (b - 1))

由于欧拉函数嵌套多层后其值会变为 1,所以如果我们要算出 a \uparrow^3 b \bmod p 的值,只要算出 a \uparrow \uparrow \min(a \uparrow^3 (b - 1), k_1) \bmod p 即可。因此,a \uparrow^3 b 可以递归计算。

预处理欧拉函数的值,然后从 0 开始枚举 x 的值并判断即可。x 的枚举上界为满足 a \uparrow \uparrow (k_2 - 1) \geq k_1 的最小 k_2。时间复杂度为 O(T \sqrt{p} \log p)

具体细节见代码注释。

代码:

#include <iostream>
#include <map>
#include <cmath>

using namespace std;

typedef long long ll;

typedef struct {
    ll val;
    bool flag;
} Node;

int phi[37];
map<ll, int> mp;

inline Node new_node(ll val, bool flag){
    Node ans;
    ans.val = val;
    ans.flag = flag;
    return ans;
}

inline Node quick_pow(ll x, ll p, ll mod){
    Node ans;
    ans.val = 1;
    ans.flag = false;
    while (p){
        if (p & 1){
            ans.val *= x;
            if (ans.val >= mod){
                ans.val %= mod;
                ans.flag = true;
            }
        }
        p >>= 1;
        if (p == 0) break;
        x *= x;
        if (x >= mod){
            x %= mod;
            ans.flag = true;
        }
    }
    return ans;
}

inline int bsgs(int a, int b, int p){
    if (p == 1) return 0;
    a %= p;
    b %= p;
    if (b == 1) return 0;
    int m = ceil(sqrt(p)), i = 0;
    ll t = quick_pow(a, m, p).val;
    mp.clear();
    for (register ll j = b; i < m; i++, j = j * a % p){
        mp[j] = i;
    }
    i = 1;
    for (register ll j = t; i <= m; i++, j = j * t % p){
        if (mp.count(j)) return i * m - mp[j];
    }
    return -1;
}

inline int euler(int n){
    int ans = n;
    for (register int i = 2; i * i <= n; i++){
        if (n % i == 0){
            ans -= ans / i;
            while (n % i == 0){
                n /= i;
            }
        }
    }
    if (n > 1) ans -= ans / n;
    return ans;
}

Node tetration(int a, int n, int index){
    if (phi[index] == 1) return new_node(0, true);
    if (n == 0) return new_node(1, false);
    int next_index = index + 1;
    Node x = tetration(a, n - 1, next_index);
    if (x.flag) x.val += phi[next_index];
    return quick_pow(a, x.val, phi[index]);
}

inline int f(int a){
    if (a == 1) return 0;
    for (register int i = 0; ; i++){
        if (tetration(a, i, 0).flag) return i + 1;
    }
}

inline ll quick_pow_with_max_val(ll x, ll p, ll max_val){
    ll ans = 1;
    while (p){
        if (x >= max_val){
            return max_val;
        }
        if (p & 1){
            ans *= x;
            if (ans >= max_val){
                return max_val;
            }
        }
        x *= x;
        p >>= 1;
    }
    return ans;
}

ll tetration_with_max_val(ll a, ll n, int max_val){
    if (n == 0 || a == 1) return 1;
    ll x = tetration_with_max_val(a, n - 1, max_val);
    if (x == max_val) return max_val;
    return quick_pow_with_max_val(a, x, max_val);
}

ll pentation_with_max_val(ll a, ll n, int max_val){
    if (n == 0 || a == 1) return 1;
    ll x = pentation_with_max_val(a, n - 1, max_val);
    if (x == max_val) return max_val;
    return tetration_with_max_val(a, x, max_val);
}

inline ll pentation(ll a, ll n, int max_val, ll mod){
    if (mod == 1) return 0;
    if (n == 0) return 1;
    return tetration(a, pentation_with_max_val(a, n - 1, max_val), 0).val % mod;
}

int main(){
    int t;
    cin >> t;
    for (register int i = 1; i <= t; i++){
        int a, n, b, p;
        cin >> a >> n >> b >> p;
        if (n == 1){
            cout << bsgs(a, b, p) << endl;
        } else {
            int maxx_2 = 0, ans = -1;
            for (register int j = p; j != 1; j = euler(j)){
                phi[maxx_2++] = j;
            }
            phi[maxx_2] = 1;
            if (n == 2){
                for (register int j = 0; j <= maxx_2; j++){
                    if (tetration(a, j, 0).val == b){
                        ans = j;
                        break;
                    }
                }
            } else {
                int maxx_3 = f(a);
                for (register int j = 0; j <= maxx_3; j++){
                    if (pentation(a, j, maxx_2, p) == b){
                        ans = j;
                        break;
                    }
                }
            }
            cout << ans << endl;
        }
    }
    return 0;
}