题解 P5282 【【模板】快速阶乘算法】
shadowice1984
2019-04-04 07:54:37
> 这篇题解算是zzq的blog的详细解释版吧
>如果看到这篇题解,请大喊三声"Min_25 牛逼!"
___________________________
## 前置芝士:任意模数fft
确保你写的不是三模ntt和未经过优化的拆系数fft这种做一次多项式乘法需要9次或者7次fft的丢人东西
~~9120年了,三次变两次优化早应该普及了~~
## 前置芝士:拉格朗日插值公式
假设我们有一个不超过$n$次的多项式$f$并且我们知道了这个多项式的m个点值对$(x_{i},f(x_{i}))$,并且$m \geq n+1$那么我们有如下公式
$$f(x)=\sum_{i=1}^{m}f(x_{i})\prod_{j \neq i} \frac{x-x_{j}}{x_{i}-x_{j}}$$
______________
# 本题题解
## $O(\sqrt{n}log^2n)$的解法
让我们来看看这道题要求我们做什么:计算$n! \mod P$
那么一个简单粗暴的思路是分块,我们希望计算$1 \dots n$的乘积,我们将序列分成长度为$B$的若干个块,如果我们能快速求出每个块的积,我们就能凑出n阶乘来
那具体来讲我构造这样的一个多项式$f(x)$
$$f(x)=\prod_{i=1}^{B}(x+i)$$
如果我们能求出$f(0),f(B),f(2B) \dots f(\lfloor \frac{n}{B} \rfloor B)$的话,我们就能在$O(\lfloor \frac{n}{B}\rfloor+B)$的时间内计算出答案了
使用多项式多点求值就可以在$O(\sqrt{N}log^2n)$的时间内计算出答案来
## $O(\sqrt{n}logn)$的做法
上面的做法可以通过玄学调整块长来优化常数,但是一个log做法只能求出$f(0),f(B) \dots f(B^2)$的值,因此我们需要把块长设为$\sqrt{n}$
我们发现我们只需要求出f这个多项式的一些点值就可以生成答案了,那么我们可以尝试不求出$f$的系数表达,而是直接使用f的点值进行迭代
为了让我们能够转移,我们将f的定义从一维的情况拓展到二维,我们定义这样一个多项式出来
$$f(d,x)=\prod_{i=1}^{d}(x+i)$$
那么我们最后要求的是$f(B,0),f(B,B) \dots f(B,B^2)$这些点值
一个显而易见的性质是,已知多项式f的$n+1$对点值$(x,f(x))$就可以唯一确定这个多项式,因此如果我们求出了$f(d,0),f(d,B) \dots f(d,dB)$我们就可以唯一的确定多项式$f(d,x)$了
现在我们尝试在d这一位上做倍增
具体来讲我们需要实现这两件事情
已知
$$f(d,0),f(d,B) \dots f(d,dB)$$
求出
$$f(2d,0),f(2d,B) \dots f(2d,2dB)$$
通过这个操作我们可以把d乘2
已知
$$f(d,0),f(d,B) \dots f(d,dB)$$
求出
$$f(d+1,0),f(d+1,B) \dots f(d+1,(d+1)B)$$
通过这个操作我们可以把$d$加$1$
然后有了这两个操作之后使用一个类似于快速幂的做法,我们迭代log轮就可以把求出$f(B,0),f(B,B) \dots f(B,B^2)$了
那么我们考虑如何实现这两个过程
## 将d乘2
我们知道$f(2d,x)=f(d,x)f(d,dx)$
因此我们只需要求出来
$$f(d,0),f(d,B) \dots f(d,2dB)$$
和
$$f(d,d),f(d,d+B) \dots f(d,2dB+d)$$
就能计算出我们想要的点值序列的
那么我们不妨构造一个新的多项式$h(x)=f(d,Bx)$
那么我们需要解决这样一个问题
已知
$$h(0),h(1),h(2) \dots h(d)$$
希望求出
$$h(d+1+0),h(d+1+1),h(d+1+1) \dots h(d+d+1)$$
求出上面的东西之后我们就已知了
$$h(0),h(1),h(2)\dots h(2d)$$
只要求出
$$h(d/B+0),h(d/B+1),h(d/B+2) \dots h(d/B+2d)$$
我们就可以把倍增需要的两个序列求出来了
所以总结一下就是已知
$$h(0),h(1) \dots h(k)$$
希望求解
$$h(\Delta+0),h(\Delta+1) \dots h(\Delta+k)$$
也就是说我们希望从点值转移到点值,那么我们学过的算法里面只有拉格朗日插值是一个完全不涉及多项式的系数还能算多项式的点值的算法
因此我们尝试使用拉格朗日插值算法
$$h(\Delta+n)=\sum_{i=0}^{k}h(i)\prod_{j \neq i}\frac{\Delta+n-j}{i-j}$$
我们将$\prod$里的分子提出来就是
$$h(\Delta+n)=\prod_{j=0}^{k}(\Delta+n-j)(\sum_{i=0}^{k}\frac{h(i)}{\Delta+n-i}\prod_{j \neq i}\frac{1}{i-j})$$
这里说明一下为什么$\Delta+n-i$不可能是0,因为$h(x)$本质上代表了一段连续数字的乘积,而根据我们算法的原理,$h(\Delta+n)$和$h(i)$一定代表了两段不同的数字,所以分母不可能为0
然后考虑$\prod_{j \neq }\frac{1}{i-j}$手玩一下会发现他是两段阶乘乘起来的,所以式子可以化成这样
$$h(\Delta+n)=\prod_{j=0}^{k}(\Delta+n-j)(\sum_{i=0}^{k}\frac{1}{\Delta+n-i}×\frac{h(i)}{i!(k-i)!(-1)^{k-i}})$$
如果我们设一些这样的函数出来
$$h'(n)=\frac{h(\Delta-k+n)}{\prod_{j=0}^{k}(\Delta+n-j)}$$
$$f(n)=\frac{h(n)}{n!(k-n)!(-1)^{k-n}}$$
$$g(n)=\frac{1}{\Delta-k+n}$$
那么我们就可以得到这样的式子
$$h'(n)=\sum_{i+j=k,i\geq 0,j\geq 0}f(i)g(j)$$
这东西当然是个标准卷积式子,把$f$和$g$求出来卷一卷就行了
知道了$h'(n+k)$之后我们就可以推出$h(n)$来了,发现$h'(n+k)$和$h(n)$的比值是一段区间当中数字的乘积,这个系数在可以$two-pointer$扫一遍在$O(k+logP)$的时间内求出
(这里直接求逆元是$O(klogP)$的,但是的确存在$O(k+logP)$的算法)
这样我们就可以成功的求出倍增需要的两个数组,于是就可以实现把$d$乘2的操作了
迭代一次是$O(nlogn)$的
## 将d加1
现在已知
$$f(d,0),f(d,B) \dots f(d,dB)$$
希望求出
$$f(d+1,0),f(d+1,B) \dots f(d+1,(d+1)B)$$
显然$f(d+1,(d+1)B)$暴力计算就行了
对于剩下的项,我们可以用这个式子计算
$$f(d+1,x)=f(d,x)(x+d+1)$$
显然迭代一次是$O(n)$的
完成了这两个操作之后我们就可以在$log$轮迭代中把d从0迭代到B了
然后我们乘一乘就能把阶乘算出来啦~
时间复杂度
$$T(n)=T(n/2)+O(nlogn)=O(nlogn)$$
上代码~
```C
// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;const int N=262144+10;typedef unsigned long long ll;
const int P=65536;const int SF=16;const int msk=65535;ll mod;ll PP;
typedef long double ld;const ld pi=acos(-1.0);
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
struct cmp
{
ld r;ld v;
friend cmp operator +(cmp a,cmp b){return (cmp){a.r+b.r,a.v+b.v};}
friend cmp operator -(cmp a,cmp b){return (cmp){a.r-b.r,a.v-b.v};}
friend cmp operator *(cmp a,cmp b){return (cmp){a.r*b.r-a.v*b.v,a.r*b.v+a.v*b.r};}
void operator /=(const int& len){r/=len;v/=len;}
}rt[2][22][N],tr[N],tr1[N],tr2[N],tr3[N],tr4[N],tr5[N],tr6[N];
int rv[22][N];ll m13[N],m14[N],m23[N],m24[N];
inline void pre()
{
for(int d=1;d<=18;d++)
for(int i=1;i<(1<<d);i++)rv[d][i]=(rv[d][i>>1]>>1)|((i&1)<<(d-1));
for(int d=1,t=1;d<=18;d++,t<<=1)
for(int i=0;i<(1<<d);i++)rt[0][d][i]=(cmp){cos(pi*i/t),sin(pi*i/t)};
for(int d=1,t=1;d<=18;d++,t<<=1)
for(int i=0;i<(1<<d);i++)rt[1][d][i]=(cmp){cos(pi*i/t),-sin(pi*i/t)};
}inline void fft(cmp* a,int len,int d,int o)
{
for(int i=1;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);cmp* w;int i;
for(int k=1,j=1;k<len;k<<=1,j++)
for(int s=0;s<len;s+=(k<<1))
for(i=s,w=rt[o][j];i<s+k;i++,++w)
{cmp a1=a[i+k]*(*w);a[i+k]=a[i]-a1;a[i]=a[i]+a1;}
if(o)for(int i=0;i<len;i++)a[i]/=len;
}inline void dbdft(ll* a,int len,int d,cmp* op1,cmp* op2)
{
for(int i=0;i<len;i++)tr[i]=(cmp){(ld)(a[i]>>SF),(ld)(a[i]&msk)};
fft(tr,len,d,0);tr[len]=tr[0];
for(cmp* p1=tr,*p2=tr+len,*p3=op1;p1!=tr+len;++p1,--p2,++p3)
(*p3)=(cmp){p1->r+p2->r,p1->v-p2->v}*(cmp){0.5,0};
for(cmp* p1=tr,*p2=tr+len,*p3=op2;p1!=tr+len;++p1,--p2,++p3)
(*p3)=(cmp){p1->r-p2->r,p1->v+p2->v}*(cmp){0,-0.5};
}inline void dbidft(cmp* tr,int len,int d,ll* a,ll* b)
{
fft(tr,len,d,1);
for(int i=0;i<len;i++)a[i]=(ll)(tr[i].r+0.5)%mod;
for(int i=0;i<len;i++)b[i]=(ll)(tr[i].v+0.5)%mod;
}inline void poly_mul(ll* a,ll* b,ll* c,int len,int d)//以上都是任意模数fft的板子
{
dbdft(a,len,d,tr1,tr2);dbdft(b,len,d,tr3,tr4);
for(int i=0;i<len;i++)tr5[i]=tr1[i]*tr3[i]+(cmp){0,1}*tr2[i]*tr4[i];
for(int i=0;i<len;i++)tr6[i]=tr2[i]*tr3[i]+(cmp){0,1}*tr1[i]*tr4[i];
dbidft(tr5,len,d,m13,m24);dbidft(tr6,len,d,m23,m14);
for(int i=0;i<len;i++)c[i]=m13[i]*PP%mod;
for(int i=0;i<len;i++)(c[i]+=(m23[i]+m14[i])*P+m24[i])%=mod;
}namespace iter
{
ll f[N];ll g[N];ll h[N];ll ifac[N];
inline void ih()
{
ifac[0]=ifac[1]=1;
for(int i=2;i<min((ll)N,mod);i++)ifac[i]=(mod-mod/i)*ifac[mod%i]%mod;
for(int i=1;i<min((ll)N,mod);i++)(ifac[i]*=ifac[i-1])%=mod;
}inline void calch(ll del,int cur,ll* ip,ll* op)
{
int d=0;int len=1;while(len<=cur+cur+cur)len<<=1,d++;
for(int i=0;i<=cur;i++)f[i]=ip[i]*ifac[i]%mod*ifac[cur-i]%mod;
for(int i=cur-1;i>=0;i-=2)f[i]=(mod-f[i])%mod;
for(int i=0;i<=cur+cur;i++)g[i]=po((del+mod-cur+i)%mod,mod-2);
for(int i=cur+1;i<len;i++)f[i]=0;for(int i=cur+cur+1;i<len;i++)g[i]=0;
poly_mul(f,g,h,len,d);//卷积求出h'
ll xs=1;ll p1=del-cur;ll p2=del;
for(int i=p1;i<=p2;i++)(xs*=i)%=mod;
for(int i=0;i<=cur;i++,p1++,p2++)//双指针求出系数
{
op[i]=h[i+cur]*xs%mod;
(xs*=po(p1,mod-2))%=mod,(xs*=(p2+1))%=mod;
}
}
}ll val[N];ll fv1[N];ll fv2[N];
inline void solve(int n)//倍增
{
int hb=0;for(int p=n;p;p>>=1)hb++;val[0]=1;
for(int z=hb,cur=0;z>=0;z--)
{
if(cur!=0)//把d乘2
{
iter::calch(cur+1,cur,val,fv1);
for(int i=0;i<=cur;i++)val[cur+i+1]=fv1[i];val[cur<<1|1]=0;
iter::calch(cur*po(n,mod-2)%mod,cur<<1,val,fv2);
cur<<=1;for(int i=0;i<=cur;i++)(val[i]*=fv2[i])%=mod;
}if((n>>z)&1)//把d加1
{
for(int i=0;i<=cur;i++)(val[i]*=(ll)(n*i)+cur+1)%=mod;cur|=1;val[cur]=1;
for(int i=1;i<=cur;i++)(val[cur]*=(ll)cur*n+i)%=mod;
}
}
}
int main()
{
pre();int n;scanf("%d%lld",&n,&mod);iter::ih();
int bl=sqrt(n);PP=(ll)P*P%mod;solve(bl);ll res=1;
for(int i=0,id=0;;i+=bl,id++)//分块
{
if((ll)i+bl>n){for(int j=i+1;j<=n;j++)(res*=j)%=mod;break;}
(res*=val[id])%=mod;
}printf("%lld",res);return 0;//拜拜程序~
}
```