Min_25系列第1篇/题解P5282 【模板】快速阶乘算法

· · 题解

相信很多人都想到了 O(n) 的解法:连着乘。此做法可以查看 @bh1234666 的博客获得更多信息。

观察题目数据可以发现:本题需要一个低于 O(n) 的解法。

我们先设 s=\lfloor\sqrt n\rfloorf(x)=\prod_{i=1}^s{x+i},那么

n!\equiv (\prod_{i=0}^{s-1}f(is))\times\prod_{i=s^2+1}^{n}i\pmod p

其中 \prod_{i=s^2+1}^{n}i 可在 O(\sqrt n) 内完成。我们希望快速计算 \prod_{i=0}^{s-1}f(is)

O(\sqrt n\log^2 n) 的做法

我们可以使用多项式连续点值平移在 O(s\log s) 得到 f(x) 的各项系数,但是通过多项式多点求值得到 f(0),f(s),f(2s),\cdots,f(s^2-s) 需要 O(s\log^2 s) 的时间复杂度。于是整体时间复杂度为 O(\sqrt n\log^2 n)

O(\sqrt n\log n) 的做法

f_j(x)=\prod_{i=1}^j(x+i)。通过倍增即可求出 \prod_{i=0}^{s-1}f(is)

\begin{aligned} f_{2j}(x)&=\prod_{i=1}^{2j}(x+i)\\ &=\prod_{i=1}^j(x+i)\times\prod_{i=j+1}^{2j}(x+i)\\ &=\prod_{i=1}^j(x+i)\times\prod_{i=1}^j(x+j+i)\\ &=f_j(x)\times f_j(j+x) \end{aligned}

现在看怎么计算。设 g(x)=f_j(sx),则:

  1. 现有的点值是 g(0),g(1),\cdots,g(j)(即 f_j(0),f_j(s),\cdots,f_j(js))。
  2. 通过多项式连续点值平移得到 g(j+1),g(j+2),\cdots,g(2j+1)(这里由于要求区间不重叠,会多平移出一个 g(2j+1),否则会因为除以零寄掉)。
  3. 将其拼接后得到 g(0),g(1),\cdots,g(2j)(即 f_j(0),f_j(0+s),\cdots,f_j(0+2js))。
  4. 再次经过平移后得到 g(0+\dfrac{j}{s}),g(1+\dfrac{j}{s}),\cdots,g(2j+\dfrac{j}{s})(即 f_j(j),f_j(j+s),\cdots,f_j(j+2js))。
  5. 乘起来即得 f_{2j}(0),f_{2j}(s),\cdots,f_{2j}(2js)

直接暴力点乘就可以了。

\begin{aligned} f_{j+1}(x)&=\prod_{i=1}^{j+1}(x+i)\\ &=\prod_{i=1}^j(x+i)\times(x+j+1)\\ &=f_j(x)\times(x+j+1) \end{aligned}

然后就可以倍增了。

我们只需要计算 s 个点值,所以时间复杂度为 O(s\log s)=O(\sqrt n\log n)

上代码:

// dill
inline int solve(int n, int p) {
    static vector<int> f, fd, st;
    st.resize(log2(n) + 5);
    f.resize(2);
    int top = 0;
    int s = n;
    while (n) {
        st[++top] = n;
        n >>= 1;
    }
    n = st[top--];
    f[0] = 1;
    f[1] = s + 1;
    while (top--) {
        f.resize(f.size() << 1);
        LagrangeInterpolation_ex(n, n + 1, p, f, fd);
        std::copy(fd.begin(), fd.end(), f.begin() + n + 1);
        f[n << 1 | 1] = 0;
        int tmp = 1ll * n * fpow(s, p - 2, p) % p;
        n <<= 1;
        LagrangeInterpolation_ex(n, tmp, p, f, fd);
        for (int i = 0; i <= n; i++) {
            f[i] = 1ll * f[i] * fd[i] % p;
        }
        if (!(st[top + 1] & 1)) continue;
        for (int i = 0; i <= n; i++) {
            f[i] = 1ll * f[i] * (1ll * s * i % p + n + 1) % p;
        }
        f[++n] = 1;
        for (int i = 1; i <= n; i++) {
            f[n] = 1ll * f[n] * (1ll * s * n % p + i) % p;
        }
    }
    int res = f[0];
    for (int i = 1; i < s; i++) {
        res = 1ll * res * f[i] % p;
    }
    return res;
}
inline int factorial(int n, int p) {
    int tn = p - 1 - n;
    int res = 0;
    if (tn <= n) {
        res = fpow(factorial(tn, p), p - 2, p);
        return (tn & 1) ? res : p - res;
    }
    int k = sqrt(n);
    res = solve(k, p);
    for (int i = k * k + 1; i <= n; i++) {
        res = 1ll * res * i % p;
    }
    return res % p;
}
signed main() {
    int T, n, p;
    cin >> T;
    while (T--) {
        cin >> n >> p;
        cout << factorial(n, p) << endl;
    }
}

习题:

  1. 阶乘模大质数
  2. 移数字