Min_25系列第1篇/题解P5282 【模板】快速阶乘算法

Ruiqun2009

2022-10-26 13:46:03

Solution

~~相信很多人都想到了 $O(n)$ 的解法:连着乘。此做法可以查看 @bh1234666 的博客获得更多信息。~~ 观察题目数据可以发现:本题需要一个低于 $O(n)$ 的解法。 我们先设 $s=\lfloor\sqrt n\lfloor$ 且 $f(x)=\prod_{i=1}^s{x+i}$,那么 $$ n!\equiv (\prod_{i=0}^{s-1}f(is))\times\prod_{i=s^2+1}^{n}i\pmod p $$ 其中 $\prod_{i=s^2+1}^{n}i$ 可在 $O(\sqrt n)$ 内完成。我们希望快速计算 $\prod_{i=0}^{s-1}f(is)$。 ## $O(\sqrt n\log^2 n)$ 的做法 我们可以使用多项式连续点值平移在 $O(s\log s)$ 得到 $f(x)$ 的各项系数,但是通过多项式多点求值得到 $f(0),f(s),f(2s),\cdots,f(s^2-s)$ 需要 $O(s\log^2 s)$ 的时间复杂度。于是整体时间复杂度为 $O(\sqrt n\log^2 n)$。 ## $O(\sqrt n\log n)$ 的做法 令 $f_j(x)=\prod_{i=1}^j(x+i)$。通过倍增即可求出 $\prod_{i=0}^{s-1}f(is)$。 - 倍增方程1(乘以 $2$) $$ \begin{aligned} f_{2j}(x)&=\prod_{i=1}^{2j}(x+i)\\ &=\prod_{i=1}^j(x+i)\times\prod_{i=j+1}^{2j}(x+i)\\ &=\prod_{i=1}^j(x+i)\times\prod_{i=1}^j(x+j+i)\\ &=f_j(x)\times f_j(j+x) \end{aligned} $$ 现在看怎么计算。设 $g(x)=f_j(sx)$,则: 1. 现有的点值是 $g(0),g(1),\cdots,g(j)$(即 $f_j(0),f_j(s),\cdots,f_j(js)$)。 2. 通过多项式连续点值平移得到 $g(j+1),g(j+2),\cdots,g(2j+1)$(这里由于要求区间不重叠,会多平移出一个 $g(2j+1)$,否则会因为除以零寄掉)。 3. 将其拼接后得到 $g(0),g(1),\cdots,g(2j)$(即 $f_j(0),f_j(0+s),\cdots,f_j(0+2js)$)。 4. 再次经过平移后得到 $g(0+\dfrac{j}{s}),g(1+\dfrac{j}{s}),\cdots,g(2j+\dfrac{j}{s})$(即 $f_j(j),f_j(j+s),\cdots,f_j(j+2js)$)。 5. 乘起来即得 $f_{2j}(0),f_{2j}(s),\cdots,f_{2j}(2js)$。 - 倍增方程2(加上 $1$) 直接暴力点乘就可以了。 $$ \begin{aligned} f_{j+1}(x)&=\prod_{i=1}^{j+1}(x+i)\\ &=\prod_{i=1}^j(x+i)\times(x+j+1)\\ &=f_j(x)\times(x+j+1) \end{aligned} $$ 然后就可以倍增了。 我们只需要计算 $s$ 个点值,所以时间复杂度为 $O(s\log s)=O(\sqrt n\log n)$。 上代码: ```cpp // dill inline int solve(int n, int p) { static vector<int> f, fd, st; st.resize(log2(n) + 5); f.resize(2); int top = 0; int s = n; while (n) { st[++top] = n; n >>= 1; } n = st[top--]; f[0] = 1; f[1] = s + 1; while (top--) { f.resize(f.size() << 1); LagrangeInterpolation_ex(n, n + 1, p, f, fd); std::copy(fd.begin(), fd.end(), f.begin() + n + 1); f[n << 1 | 1] = 0; int tmp = 1ll * n * fpow(s, p - 2, p) % p; n <<= 1; LagrangeInterpolation_ex(n, tmp, p, f, fd); for (int i = 0; i <= n; i++) { f[i] = 1ll * f[i] * fd[i] % p; } if (!(st[top + 1] & 1)) continue; for (int i = 0; i <= n; i++) { f[i] = 1ll * f[i] * (1ll * s * i % p + n + 1) % p; } f[++n] = 1; for (int i = 1; i <= n; i++) { f[n] = 1ll * f[n] * (1ll * s * n % p + i) % p; } } int res = f[0]; for (int i = 1; i < s; i++) { res = 1ll * res * f[i] % p; } return res; } inline int factorial(int n, int p) { int tn = p - 1 - n; int res = 0; if (tn <= n) { res = fpow(factorial(tn, p), p - 2, p); return (tn & 1) ? res : p - res; } int k = sqrt(n); res = solve(k, p); for (int i = k * k + 1; i <= n; i++) { res = 1ll * res * i % p; } return res % p; } signed main() { int T, n, p; cin >> T; while (T--) { cin >> n >> p; cout << factorial(n, p) << endl; } } ``` 习题: 1. [阶乘模大质数](https://loj.ac/p/170) 2. [移数字](http://www.51nod.com/Challenge/Problem.html#problemId=1387)