题解 P5282 【【模板】快速阶乘算法】
shadowice1984 · · 题解
这篇题解算是zzq的blog的详细解释版吧
如果看到这篇题解,请大喊三声"Min_25 牛逼!"
前置芝士:任意模数fft
确保你写的不是三模ntt和未经过优化的拆系数fft这种做一次多项式乘法需要9次或者7次fft的丢人东西
9120年了,三次变两次优化早应该普及了
前置芝士:拉格朗日插值公式
假设我们有一个不超过
本题题解
O(\sqrt{n}log^2n) 的解法
让我们来看看这道题要求我们做什么:计算
那么一个简单粗暴的思路是分块,我们希望计算
那具体来讲我构造这样的一个多项式
如果我们能求出
使用多项式多点求值就可以在
O(\sqrt{n}logn) 的做法
上面的做法可以通过玄学调整块长来优化常数,但是一个log做法只能求出
我们发现我们只需要求出f这个多项式的一些点值就可以生成答案了,那么我们可以尝试不求出
为了让我们能够转移,我们将f的定义从一维的情况拓展到二维,我们定义这样一个多项式出来
那么我们最后要求的是
一个显而易见的性质是,已知多项式f的
现在我们尝试在d这一位上做倍增
具体来讲我们需要实现这两件事情
已知
求出
通过这个操作我们可以把d乘2
已知
求出
通过这个操作我们可以把
然后有了这两个操作之后使用一个类似于快速幂的做法,我们迭代log轮就可以把求出
那么我们考虑如何实现这两个过程
将d乘2
我们知道
因此我们只需要求出来
和
就能计算出我们想要的点值序列的
那么我们不妨构造一个新的多项式
那么我们需要解决这样一个问题
已知
希望求出
求出上面的东西之后我们就已知了
只要求出
我们就可以把倍增需要的两个序列求出来了
所以总结一下就是已知
希望求解
也就是说我们希望从点值转移到点值,那么我们学过的算法里面只有拉格朗日插值是一个完全不涉及多项式的系数还能算多项式的点值的算法
因此我们尝试使用拉格朗日插值算法
我们将
这里说明一下为什么
然后考虑
如果我们设一些这样的函数出来
那么我们就可以得到这样的式子
这东西当然是个标准卷积式子,把
知道了
(这里直接求逆元是
这样我们就可以成功的求出倍增需要的两个数组,于是就可以实现把
迭代一次是
将d加1
现在已知
希望求出
显然
对于剩下的项,我们可以用这个式子计算
显然迭代一次是
完成了这两个操作之后我们就可以在
然后我们乘一乘就能把阶乘算出来啦~
时间复杂度
上代码~
// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;const int N=262144+10;typedef unsigned long long ll;
const int P=65536;const int SF=16;const int msk=65535;ll mod;ll PP;
typedef long double ld;const ld pi=acos(-1.0);
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
struct cmp
{
ld r;ld v;
friend cmp operator +(cmp a,cmp b){return (cmp){a.r+b.r,a.v+b.v};}
friend cmp operator -(cmp a,cmp b){return (cmp){a.r-b.r,a.v-b.v};}
friend cmp operator *(cmp a,cmp b){return (cmp){a.r*b.r-a.v*b.v,a.r*b.v+a.v*b.r};}
void operator /=(const int& len){r/=len;v/=len;}
}rt[2][22][N],tr[N],tr1[N],tr2[N],tr3[N],tr4[N],tr5[N],tr6[N];
int rv[22][N];ll m13[N],m14[N],m23[N],m24[N];
inline void pre()
{
for(int d=1;d<=18;d++)
for(int i=1;i<(1<<d);i++)rv[d][i]=(rv[d][i>>1]>>1)|((i&1)<<(d-1));
for(int d=1,t=1;d<=18;d++,t<<=1)
for(int i=0;i<(1<<d);i++)rt[0][d][i]=(cmp){cos(pi*i/t),sin(pi*i/t)};
for(int d=1,t=1;d<=18;d++,t<<=1)
for(int i=0;i<(1<<d);i++)rt[1][d][i]=(cmp){cos(pi*i/t),-sin(pi*i/t)};
}inline void fft(cmp* a,int len,int d,int o)
{
for(int i=1;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);cmp* w;int i;
for(int k=1,j=1;k<len;k<<=1,j++)
for(int s=0;s<len;s+=(k<<1))
for(i=s,w=rt[o][j];i<s+k;i++,++w)
{cmp a1=a[i+k]*(*w);a[i+k]=a[i]-a1;a[i]=a[i]+a1;}
if(o)for(int i=0;i<len;i++)a[i]/=len;
}inline void dbdft(ll* a,int len,int d,cmp* op1,cmp* op2)
{
for(int i=0;i<len;i++)tr[i]=(cmp){(ld)(a[i]>>SF),(ld)(a[i]&msk)};
fft(tr,len,d,0);tr[len]=tr[0];
for(cmp* p1=tr,*p2=tr+len,*p3=op1;p1!=tr+len;++p1,--p2,++p3)
(*p3)=(cmp){p1->r+p2->r,p1->v-p2->v}*(cmp){0.5,0};
for(cmp* p1=tr,*p2=tr+len,*p3=op2;p1!=tr+len;++p1,--p2,++p3)
(*p3)=(cmp){p1->r-p2->r,p1->v+p2->v}*(cmp){0,-0.5};
}inline void dbidft(cmp* tr,int len,int d,ll* a,ll* b)
{
fft(tr,len,d,1);
for(int i=0;i<len;i++)a[i]=(ll)(tr[i].r+0.5)%mod;
for(int i=0;i<len;i++)b[i]=(ll)(tr[i].v+0.5)%mod;
}inline void poly_mul(ll* a,ll* b,ll* c,int len,int d)//以上都是任意模数fft的板子
{
dbdft(a,len,d,tr1,tr2);dbdft(b,len,d,tr3,tr4);
for(int i=0;i<len;i++)tr5[i]=tr1[i]*tr3[i]+(cmp){0,1}*tr2[i]*tr4[i];
for(int i=0;i<len;i++)tr6[i]=tr2[i]*tr3[i]+(cmp){0,1}*tr1[i]*tr4[i];
dbidft(tr5,len,d,m13,m24);dbidft(tr6,len,d,m23,m14);
for(int i=0;i<len;i++)c[i]=m13[i]*PP%mod;
for(int i=0;i<len;i++)(c[i]+=(m23[i]+m14[i])*P+m24[i])%=mod;
}namespace iter
{
ll f[N];ll g[N];ll h[N];ll ifac[N];
inline void ih()
{
ifac[0]=ifac[1]=1;
for(int i=2;i<min((ll)N,mod);i++)ifac[i]=(mod-mod/i)*ifac[mod%i]%mod;
for(int i=1;i<min((ll)N,mod);i++)(ifac[i]*=ifac[i-1])%=mod;
}inline void calch(ll del,int cur,ll* ip,ll* op)
{
int d=0;int len=1;while(len<=cur+cur+cur)len<<=1,d++;
for(int i=0;i<=cur;i++)f[i]=ip[i]*ifac[i]%mod*ifac[cur-i]%mod;
for(int i=cur-1;i>=0;i-=2)f[i]=(mod-f[i])%mod;
for(int i=0;i<=cur+cur;i++)g[i]=po((del+mod-cur+i)%mod,mod-2);
for(int i=cur+1;i<len;i++)f[i]=0;for(int i=cur+cur+1;i<len;i++)g[i]=0;
poly_mul(f,g,h,len,d);//卷积求出h'
ll xs=1;ll p1=del-cur;ll p2=del;
for(int i=p1;i<=p2;i++)(xs*=i)%=mod;
for(int i=0;i<=cur;i++,p1++,p2++)//双指针求出系数
{
op[i]=h[i+cur]*xs%mod;
(xs*=po(p1,mod-2))%=mod,(xs*=(p2+1))%=mod;
}
}
}ll val[N];ll fv1[N];ll fv2[N];
inline void solve(int n)//倍增
{
int hb=0;for(int p=n;p;p>>=1)hb++;val[0]=1;
for(int z=hb,cur=0;z>=0;z--)
{
if(cur!=0)//把d乘2
{
iter::calch(cur+1,cur,val,fv1);
for(int i=0;i<=cur;i++)val[cur+i+1]=fv1[i];val[cur<<1|1]=0;
iter::calch(cur*po(n,mod-2)%mod,cur<<1,val,fv2);
cur<<=1;for(int i=0;i<=cur;i++)(val[i]*=fv2[i])%=mod;
}if((n>>z)&1)//把d加1
{
for(int i=0;i<=cur;i++)(val[i]*=(ll)(n*i)+cur+1)%=mod;cur|=1;val[cur]=1;
for(int i=1;i<=cur;i++)(val[cur]*=(ll)cur*n+i)%=mod;
}
}
}
int main()
{
pre();int n;scanf("%d%lld",&n,&mod);iter::ih();
int bl=sqrt(n);PP=(ll)P*P%mod;solve(bl);ll res=1;
for(int i=0,id=0;;i+=bl,id++)//分块
{
if((ll)i+bl>n){for(int j=i+1;j<=n;j++)(res*=j)%=mod;break;}
(res*=val[id])%=mod;
}printf("%lld",res);return 0;//拜拜程序~
}