题解:P11175 【模板】基于值域预处理的快速离散对数

· · 题解

前段时间做区域赛题碰到几个可以用离散对数转换成卷积的题,发现需要对 1\sim n 求出离散对数。发现还不会这个科技,所以学习一下。

思路

\log n 表示 n 在模 P 意义下以 g 为原根的离散对数。

先解决求出 1\sim n 内每个数离散对数值的问题。

容易发现 \log ab=\log a+\log b。假如我们要求出 \log i,\ i=1,2,\cdots,n,可以利用筛法,这样只需要求出 \pi(n) 个素数处的离散对数值即可,其中 \pi(n) 表示 n 以内素数个数。

求解一个数的离散对数,通常使用大步小步算法(BSGS)。记 x=\log y=Bi+j \in [0, P),其中 0\le j<B,0\le i\le \lfloor (P-1)/B\rfloor。那么有:

\begin{aligned} g^{Bi+j} &\equiv y &\pmod P \\ g^{j} &\equiv y\times g^{-Bi} &\pmod P \\ \end{aligned}

所以我们枚举 j=0,1,\cdots,B-1,把 g^j 插入哈希表里,接着枚举 i=0,1,\cdots,查询哈希表里有没有对应的 j 即可。复杂度为 \mathcal O(B+P/B),通常取 B=\sqrt P 使得总复杂度为 \mathcal O(\sqrt P)

然而我们这次一共要对 \pi(n) 个数字求离散对数。事实上,我们向哈希表里插入元素的复杂度为 \mathcal O(B),查询的复杂度为 \mathcal O(\pi(n)P/B),其实取 B=\sqrt{P\pi(n)} 才取得最优复杂度。这一步值得注意,有些复杂度错误的批量求解离散对数的做法就是块长取错了。

现在我们要多次询问 y\in [1, P) 的离散对数。有一个小技巧:

先预处理出 \sqrt{P} + 1 内所有数的离散对数。对于每次询问,若 y\le \sqrt P + 1 则直接回答,否则设 P=vy+r,其中 0\le r< y0\le v \le \sqrt P

根据 P=vy+r 可知 y=(P-r)/v,从而得到:

\log y \equiv \log(-r)-\log v\equiv \log(P-1)+\log r-\log v

根据 P=(v+1)y+r-y 可知 y=(P+y-r)/(v+1),从而得到:

\log y\equiv \log(y-r)-\log(v+1)

其中 \log v 或是 \log (v+1) 都可直接查表获得。由于 \min(r, y-r)\le y/2,所以每次迭代都能使得 y 的规模减半,直到 y\le \sqrt{P}+1 查表回答。

于是我们得到了一个 \mathcal O(P^{0.75}/ \log^{0.5} P) 复杂度预处理,\mathcal O(\log P) 回答单次询问的做法。

参考代码

#include<bits/stdc++.h>
using namespace std;

int power(int a, int b, int  p){
    int r = 1;
    while(b){
        if(b & 1) r = 1ll * r * a % p;
        b >>= 1,  a = 1ll * a * a % p;
    }
    return r;
}
namespace BSGS {
    unordered_map <int, int> M;
    int B, U, P, g;
    void init(int g, int P0, int B0){
        M.clear();
        B = B0;
        P = P0;
        U = power(power(g, B, P), P - 2, P);
        int w = 1;
        for(int i = 0;i < B;++ i){
            M[w] = i;
            w = 1ll * w * g % P;
        }
    }
    int solve(int y){
        int w = y;
        for(int i = 0;i <= P / B;++ i){
            if(M.count(w)){
                return i * B + M[w];
            }
            w = 1ll * w * U % P;
        }
        return -1;
    }
}

const int MAXN = 1e7 + 3;
int H[MAXN], P[MAXN], H0, p, h, g, mod;
bool V[MAXN];

int solve(int x){
    if(x <= h){
        return H[x];
    }
    int v = mod / x, r = mod % x;
    if(r < x - r){
        return ((H0 + solve(r)) % (mod - 1) - H[v] + mod - 1) % (mod - 1);
    } else {
        return (solve(x - r) - H[v + 1] + mod - 1) % (mod - 1);
    }
}

int main(){
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);

    int T;
    cin >> mod >> g;
    h = sqrt(mod) + 1;

    BSGS :: init(g, mod, sqrt(1ll * mod * sqrt(mod) / log(mod)));
    H0 = BSGS :: solve(mod - 1);

    H[1] = 0;
    for(int i = 2;i <= h;++ i){
        if(!V[i]){
            P[++ p] = i;
            H[i] = BSGS :: solve(i);
        }
        for(int j = 1;j <= p && P[j] <= h / i;++ j){
            int &p = P[j];
            H[i * p] = (H[i] + H[p]) % (mod - 1);
            V[i * p] = true;
            if(i % p == 0)
                break;
        }
    }

    cin >> T;
    while(T --){
        int x, tmp = 0;
        cin >> x;
        cout << solve(x) << "\n";
    }
    return 0;
}

Bonus:对于 NTT 模数存在更加优秀的做法,但我太懒了还没学,欢迎大家在题解区补充。