题解:P5282 【模板】快速阶乘算法

· · 题解

Solution

考虑分块,令 m=\lfloor\sqrt n\rfloor,f(n,x)=\prod\limits_{i=1}^n (x+i),我们要求的是 f(m,0),f(m,m),\dots,f(m,m^2)。

一种暴力的做法是分治求出 f(m,x) 的多项式,然后直接跑多项式多点求值,用脚维护,不需要任何脑子,但是时间复杂度是 O(T\sqrt n\log^2 n),这并不优雅。

考虑倍增,并动态维护 f(n,0),f(n,m),\dots,f(n,nm)(因为 f(n,x) 是关于 x 的 n 次多项式,所以考虑维护 n+1 个点值)。

n 变为 n+1 是简单的,考虑如何做 n 变为 2n。首先有 f(2n,xm)=f(n,xm)f(n,xm+n),因此我们需要求出 f(n,0),f(n,m),\dots,f(n,2nm) 和 f(n,n),f(n,m+n),\dots,f(n,2nm+n)。注意到要求的这两组点都是等差数列,且公差与维护的点值公差都为 m,因此考虑把 x 都除以 m(在模意义下),然后跑两次 拉格朗日插值2 即可。

注意不保证模数可以 \text{NTT},因此要跑 任意模数多项式乘法

时间复杂度是 T(m)=T\left(\dfrac m2\right)+O(m\log m),T(m) 显然就为 O(m\log m),因此总时间复杂度是 O(T\sqrt n\log n)。

Code

//when you use vector or deque,pay attention to the size of it.
//by OldDirverTree
#include<bits/stdc++.h>
//#include<atcoder/all>
// Shorthand for a pair of integers. The `int` inside is itself rewritten by the macro
// below when P is expanded. Not referenced in the code shown here.
#define P pair<int,int>
// Every `int` token after this line is compiled as `long long`, so all "int" arithmetic
// in this file is 64-bit. This is also why the entry point cannot be spelled `int main`.
#define int long long
// Expands to (l+r)>>1 using whatever `l` and `r` are in scope at the point of use
// (`+` binds tighter than `>>`). Not referenced in the code shown here.
#define mid (l+r>>1)
using namespace std;
//using namespace atcoder;
//using mint=modint998244353;
// pi, used to build the FFT roots of unity.
// NOTE(review): acos(-1.0) takes a double argument, so the value presumably carries only
// double precision even though it is stored in a long double.
const long double Pi=acos(-1.0);
// Coefficient list or point-value list; elements are 64-bit because of the macro above.
using poly=vector<int>;
// mod: modulus of the current test case (assigned in main).
// tot: FFT length currently in use (assigned in mul, read by FFT/IFFT).
// rev: bit-reversal permutation for length tot (filled in mul, read by FFT).
int mod,tot,rev[1<<20];

// Bare-bones complex number (x = real part, y = imaginary part) used by the FFT routines.
// It is kept an aggregate so that brace initialisation such as {1,0} keeps working.
struct Complex {
    long double x,y;
    // Component-wise sum.
    Complex operator +(Complex o)const {
        o.x=x+o.x; o.y=y+o.y;
        return o;
    }
    // Component-wise difference (*this - o).
    Complex operator -(Complex o)const {
        o.x=x-o.x; o.y=y-o.y;
        return o;
    }
    // Complex product (*this * o).
    Complex operator *(Complex o)const {
        const long double re=x*o.x-y*o.y;
        const long double im=x*o.y+y*o.x;
        return {re,im};
    }
}A[1<<20],B[1<<20],C[1<<20],D[1<<20]; // scratch buffers used by mul(): A,B hold the transforms, C,D their conjugated mirrors

// Hash functor for 64-bit integer keys: the key is offset by a value taken once from the
// steady clock and then scrambled with the splitmix64 mixer.
struct custom_hash
{
    // splitmix64 finaliser: add the golden-ratio increment, then two xor-shift/multiply
    // rounds and a final xor-shift.
    static uint64_t splitmix64(uint64_t x) {
        uint64_t z=x+0x9e3779b97f4a7c15;
        z^=z>>30; z*=0xbf58476d1ce4e5b9;
        z^=z>>27; z*=0x94d049bb133111eb;
        z^=z>>31;
        return z;
    }
    // Hashes x; the offset is sampled on the first call and reused afterwards, so results
    // are stable within one run but differ between runs.
    size_t operator() (uint64_t x) const {
        static const uint64_t seed=std::chrono::steady_clock::now().time_since_epoch().count();
        const uint64_t shifted=x+seed;
        return splitmix64(shifted);
    }
};
// Reads the next decimal integer from stdin.
// Skips every non-digit character; if any skipped character is '-', the result is negated.
// Consumes the digits plus the single character that follows them.
// Returns 0 when end of input is reached before any digit is seen (the previous version
// looped forever in that case, because isdigit(EOF) is false and getchar keeps returning EOF).
int read() {
    int x=0; bool neg=false;
    // Keep getchar's result in an integer type: EOF stays distinguishable and isdigit
    // never receives a negative char value.
    int c=getchar();
    while (c!=EOF&&!isdigit(c) ) {
        if (c=='-') neg=true;
        c=getchar();
    }
    while (isdigit(c) ) x=x*10+(c&15),c=getchar();
    return neg?-x:x;
}
// Binary exponentiation: returns a^b reduced modulo the global `mod`.
// The default exponent mod-2 is evaluated at call time; for a prime modulus it yields the
// modular inverse of a (presumably the intended use; callers pass values already below mod).
int power(int a,int b=mod-2)
{
    int acc=1;
    for (;b;b>>=1) {
        if (b&1) acc=acc*a%mod;   // current bit set: fold the current square into the result
        a=a*a%mod;
    }
    return acc;
}
// In-place iterative radix-2 transform of a[0..tot-1].
// Relies on the globals `tot` (length) and `rev` (bit-reversal table), both prepared by mul().
// The twiddle factors are powers of cos(pi/half)+i*sin(pi/half), built by repeated multiplication.
void FFT(Complex *a)
{
    // Reorder the input into bit-reversed order; each pair is swapped exactly once.
    for (int pos=0;pos<tot;pos++) {
        const int partner=rev[pos];
        if (pos<partner) swap(a[pos],a[partner]);
    }
    for (int half=1;half<tot;half<<=1)
    {
        const Complex step={cos(Pi/half),sin(Pi/half)};
        const int span=half<<1;
        for (int base=0;base<tot;base+=span) {
            Complex twiddle={1,0};
            for (int j=0;j<half;j++) {
                // base is a multiple of 2*half and j<half, so | acts as addition here.
                Complex &lo=a[base|j],&hi=a[base|half|j];
                const Complex even=lo,odd=hi*twiddle;
                lo=even+odd;
                hi=even-odd;
                twiddle=twiddle*step;
            }
        }
    }
}
// Inverse transform of a[0..tot-1]: run the forward transform, mirror entries 1..tot-1,
// then scale both components of every entry by 1/tot.
void IFFT(Complex *a) {
    FFT(a);
    reverse(a+1,a+tot);
    for (Complex *p=a;p!=a+tot;++p) {
        p->x/=tot;
        p->y/=tot;
    }
}
// Multiplies two polynomials modulo the global `mod`, for an arbitrary modulus.
//
// Each coefficient v is split into v>>16 (stored as the real part) and v&65535 (imaginary
// part). One forward FFT per operand plus its conjugated mirror separates the two halves,
// and two inverse FFTs then yield the partial products
//   x = hi*hi,  y = hi*lo + lo*hi,  z = lo*lo,
// which are recombined as x*2^32 + y*2^16 + z (mod `mod`).
//
// a, b   : coefficient vectors (passed by value; `a` is reused for the result).
//          Assumes entries lie in [0,mod) and mod < 2^31 -- TODO confirm against callers.
// return : the a.size()+b.size()-1 coefficients of the product, reduced modulo mod;
//          an empty vector when that length is not positive.
//
// Side effects: overwrites the globals tot and rev and the scratch buffers A, B, C, D.
// Precondition: the padded length (next power of two >= result length) must not exceed
// 1<<20, the capacity of rev and the scratch buffers.
poly mul(poly a,poly b)
{
    // Compute the length in signed arithmetic so that two empty operands give -1 rather
    // than a wrapped unsigned value.
    int n=(int)a.size()+(int)b.size()-1;
    if (n<=0) return poly();
    tot=1;
    int l=-1; while (tot<n) tot<<=1,l++;
    for (int i=0;i<tot;i++) A[i]=B[i]={0,0};
    // rev[0] is always 0. Starting at 1 guarantees tot>=2 and therefore l>=0, so the shift
    // below is never by a negative amount (which happened for a length-1 product).
    rev[0]=0;
    for (int i=1;i<tot;i++) rev[i]=rev[i>>1]>>1|(i&1)<<l;
    // Explicit casts: integer -> floating inside braces is a narrowing conversion.
    for (int i=0;i<(int)a.size();i++) A[i]={(long double)(a[i]>>16),(long double)(a[i]&65535)};
    for (int i=0;i<(int)b.size();i++) B[i]={(long double)(b[i]>>16),(long double)(b[i]&65535)};
    // C and D hold the conjugate of the mirrored spectrum, i.e. the transform of (hi - i*lo).
    FFT(A); for (int i=0;i<tot;i++) C[i]=A[(tot-i)%tot],C[i].y*=-1;
    FFT(B); for (int i=0;i<tot;i++) D[i]=B[(tot-i)%tot],D[i].y*=-1;
    const Complex quarter={0.25,0},negQuarter={-0.25,0};
    const Complex quarterI={0,0.25},negQuarterI={0,-0.25};
    for (int i=0;i<tot;i++) {
        const Complex x=A[i],y=B[i],p=C[i],q=D[i];
        // A becomes F(a_hi)*F(b_hi) + i*F(a_hi)*F(b_lo)
        A[i]=(p+x)*(q+y)*quarter+(p+x)*(q-y)*negQuarter;
        // B becomes F(a_lo)*F(b_hi) + i*F(a_lo)*F(b_lo)
        B[i]=(p-x)*(q+y)*quarterI+(p-x)*(q-y)*negQuarterI;
    }
    IFFT(A),IFFT(B),a.resize(n);
    // Reduce every term before summing: (x%mod<<32)+(y%mod<<16)+z can exceed 2^63-1 once
    // mod approaches 2^31, which is signed overflow. With reduced weights each product
    // stays below 2^62 and the final sum below 3*2^31.
    const int hi=(1ll<<32)%mod,lo=(1ll<<16)%mod;
    for (int i=0;i<n;i++) {
        int x=round(A[i].x),y=round(A[i].y+B[i].x),z=round(B[i].y);
        a[i]=(x%mod*hi%mod+y%mod*lo%mod+z%mod)%mod;
    }
    return a;
}
// Lagrange interpolation on consecutive integer nodes.
//
// a    : the m = a.size() samples a[i] = f(i), i = 0..m-1, of a polynomial f of degree < m,
//        taken modulo the global `mod`.
// l, r : inclusive range of evaluation points.
// return : f(l), f(l+1), ..., f(r) modulo mod (r-l+1 values).
//
// Assumes 0 <= l <= r < mod, m <= mod and a prime modulus, so that every value that gets
// inverted is non-zero -- TODO confirm against callers.
// Uses mul() and power(), and therefore the shared FFT scratch state.
poly Lagrange(poly a,int l,int r)
{
    const int m=a.size();
    if (l<m) {
        // Part of the range coincides with the sample nodes: copy those values directly
        // and interpolate only the part that starts at m.
        poly out(r-l+1);
        for (int x=l;x<m&&x<=r;x++) out[x-l]=a[x];
        if (r<m) return out;
        const poly rest=Lagrange(a,m,r);
        for (int x=m;x<=r;x++) out[x-l]=rest[x-m];
        return out;
    }
    const int span=r-l+m;
    poly fact(span+1),ifact(span+1),inv(span);
    fact[0]=ifact[0]=ifact[1]=1;
    // Linear-time table of inverses of 1..m-1, then prefix products -> inverse factorials.
    for (int i=2;i<m;i++) ifact[i]=ifact[mod%i]*(mod-mod/i)%mod;
    for (int i=1;i<m;i++) ifact[i]=ifact[i-1]*ifact[i]%mod;
    // Barycentric weights: a[i] / (i! * (m-1-i)!) with sign (-1)^(m-1-i).
    for (int i=0;i<m;i++) {
        int w=a[i]*ifact[i]%mod*ifact[m-i-1]%mod;
        if ( (m-i-1)&1) w=(mod-w)%mod;
        a[i]=w;
    }
    // From here on fact/ifact are reused as prefix products of (l-m+1),(l-m+2),... and
    // their inverses; fact[0]=1 is still valid and ifact is overwritten from the top down.
    for (int i=1;i<=span;i++) fact[i]=fact[i-1]*(i+l-m)%mod;
    ifact[span]=power(fact[span]);
    for (int i=span;i;i--) ifact[i-1]=ifact[i]*(i+l-m)%mod;
    // inv[i] = 1/(l-m+1+i), obtained as (prefix up to i) / (prefix up to i+1).
    for (int i=0;i<span;i++) inv[i]=ifact[i+1]*fact[i]%mod;
    // Coefficient t+m-1 of the product is sum_j a[j]/(l+t-j).
    const poly conv=mul(a,inv);
    poly out(r-l+1);
    // Multiply by (l+t)(l+t-1)...(l+t-m+1) = fact[t+m]/fact[t].
    for (int t=0;t<=r-l;t++) out[t]=conv[t+m-1]*fact[t+m]%mod*ifact[t]%mod;
    return out;
}
main()
{
    int T=read(),n;
    while (T--)
    {
        n=read(),mod=read(); int m=sqrt(n),res=1;
        poly a(2); a[0]=1,a[1]=(m+1)%mod;
        for (int i=__lg(m)-1,n=1;~i;i--)
        {
            int tmp=n*power(m)%mod; n<<=1;
            poly b=Lagrange(a,0,n),c=Lagrange(a,tmp,min(tmp+n,mod-1) );
            if (tmp+n>=mod) { poly d=Lagrange(a,0,(tmp+n)%mod); for (int x:d) c.push_back(x); }
            a.resize(n+1); for (int i=0;i<=n;i++) a[i]=b[i]*c[i]%mod;
            if (m>>i&1) {
                int val=1; n++;
                for (int i=0;i<n;i++)
                a[i]=a[i]*(i*m+n)%mod,
                val=val*(n*m+i+1)%mod;
                a.push_back(val);
            }
        }
        for (int i=0;i<m;i++) res=res*a[i]%mod;
        for (int i=m*m+1;i<=n;i++) res=res*i%mod;
        printf("%lld\n",res);
    }
    return 0;
}