Min_25系列第1篇/题解P5282 【模板】快速阶乘算法
Ruiqun2009
2022-10-26 13:46:03
~~相信很多人都想到了 $O(n)$ 的解法:连着乘。此做法可以查看 @bh1234666 的博客获得更多信息。~~
观察题目数据可以发现:本题需要一个低于 $O(n)$ 的解法。
我们先设 $s=\lfloor\sqrt n\lfloor$ 且 $f(x)=\prod_{i=1}^s{x+i}$,那么
$$
n!\equiv (\prod_{i=0}^{s-1}f(is))\times\prod_{i=s^2+1}^{n}i\pmod p
$$
其中 $\prod_{i=s^2+1}^{n}i$ 可在 $O(\sqrt n)$ 内完成。我们希望快速计算 $\prod_{i=0}^{s-1}f(is)$。
## $O(\sqrt n\log^2 n)$ 的做法
我们可以使用多项式连续点值平移在 $O(s\log s)$ 得到 $f(x)$ 的各项系数,但是通过多项式多点求值得到 $f(0),f(s),f(2s),\cdots,f(s^2-s)$ 需要 $O(s\log^2 s)$ 的时间复杂度。于是整体时间复杂度为 $O(\sqrt n\log^2 n)$。
## $O(\sqrt n\log n)$ 的做法
令 $f_j(x)=\prod_{i=1}^j(x+i)$。通过倍增即可求出 $\prod_{i=0}^{s-1}f(is)$。
- 倍增方程1(乘以 $2$)
$$
\begin{aligned}
f_{2j}(x)&=\prod_{i=1}^{2j}(x+i)\\
&=\prod_{i=1}^j(x+i)\times\prod_{i=j+1}^{2j}(x+i)\\
&=\prod_{i=1}^j(x+i)\times\prod_{i=1}^j(x+j+i)\\
&=f_j(x)\times f_j(j+x)
\end{aligned}
$$
现在看怎么计算。设 $g(x)=f_j(sx)$,则:
1. 现有的点值是 $g(0),g(1),\cdots,g(j)$(即 $f_j(0),f_j(s),\cdots,f_j(js)$)。
2. 通过多项式连续点值平移得到 $g(j+1),g(j+2),\cdots,g(2j+1)$(这里由于要求区间不重叠,会多平移出一个 $g(2j+1)$,否则会因为除以零寄掉)。
3. 将其拼接后得到 $g(0),g(1),\cdots,g(2j)$(即 $f_j(0),f_j(0+s),\cdots,f_j(0+2js)$)。
4. 再次经过平移后得到 $g(0+\dfrac{j}{s}),g(1+\dfrac{j}{s}),\cdots,g(2j+\dfrac{j}{s})$(即 $f_j(j),f_j(j+s),\cdots,f_j(j+2js)$)。
5. 乘起来即得 $f_{2j}(0),f_{2j}(s),\cdots,f_{2j}(2js)$。
- 倍增方程2(加上 $1$)
直接暴力点乘就可以了。
$$
\begin{aligned}
f_{j+1}(x)&=\prod_{i=1}^{j+1}(x+i)\\
&=\prod_{i=1}^j(x+i)\times(x+j+1)\\
&=f_j(x)\times(x+j+1)
\end{aligned}
$$
然后就可以倍增了。
我们只需要计算 $s$ 个点值,所以时间复杂度为 $O(s\log s)=O(\sqrt n\log n)$。
上代码:
```cpp
// dill
inline int solve(int n, int p) {
static vector<int> f, fd, st;
st.resize(log2(n) + 5);
f.resize(2);
int top = 0;
int s = n;
while (n) {
st[++top] = n;
n >>= 1;
}
n = st[top--];
f[0] = 1;
f[1] = s + 1;
while (top--) {
f.resize(f.size() << 1);
LagrangeInterpolation_ex(n, n + 1, p, f, fd);
std::copy(fd.begin(), fd.end(), f.begin() + n + 1);
f[n << 1 | 1] = 0;
int tmp = 1ll * n * fpow(s, p - 2, p) % p;
n <<= 1;
LagrangeInterpolation_ex(n, tmp, p, f, fd);
for (int i = 0; i <= n; i++) {
f[i] = 1ll * f[i] * fd[i] % p;
}
if (!(st[top + 1] & 1)) continue;
for (int i = 0; i <= n; i++) {
f[i] = 1ll * f[i] * (1ll * s * i % p + n + 1) % p;
}
f[++n] = 1;
for (int i = 1; i <= n; i++) {
f[n] = 1ll * f[n] * (1ll * s * n % p + i) % p;
}
}
int res = f[0];
for (int i = 1; i < s; i++) {
res = 1ll * res * f[i] % p;
}
return res;
}
inline int factorial(int n, int p) {
int tn = p - 1 - n;
int res = 0;
if (tn <= n) {
res = fpow(factorial(tn, p), p - 2, p);
return (tn & 1) ? res : p - res;
}
int k = sqrt(n);
res = solve(k, p);
for (int i = k * k + 1; i <= n; i++) {
res = 1ll * res * i % p;
}
return res % p;
}
signed main() {
int T, n, p;
cin >> T;
while (T--) {
cin >> n >> p;
cout << factorial(n, p) << endl;
}
}
```
习题:
1. [阶乘模大质数](https://loj.ac/p/170)
2. [移数字](http://www.51nod.com/Challenge/Problem.html#problemId=1387)