题解:P5282 【模板】快速阶乘算法

· · 题解

声明:本题解非 O(\sqrt{n}\log n) 正解,欲寻正解请勿参考此篇题解。

P5282 【模板】快速阶乘算法

有一天,\mathrm{Afishinsea} 在随机跳题时跳到了这题,并惊喜地发现了这篇暴力卡常题解,可惜的是,由于本题时限的修改,此篇题解再也无法通过(证据),于是 ta 决定尝试在不使用多项式的 O(T\sqrt{n}\log n)/O(T\sqrt{n}\log^2 n) 做法卡常通过此题。

回归正题,题目求 n!\ \mathrm{mod}\ p 其中 n,p 均为 int 范围内的正整数, p 为质数。

考虑对 n! 质因数分解再乘起来:

对于质数 p_i1\sim np_i 的倍数有 \lfloor\frac{n}{p_i}\rfloor 个,p_i^2 的倍数有 \lfloor\frac{n}{p_i^2}\rfloor 个,p_i^k 的倍数有 \lfloor\frac{n}{p_i^k}\rfloor 个。

对于 p_i ,我们可以先让最终答案乘以 p_i^{\lfloor\frac{n}{p_i}\rfloor+\lfloor\frac{n}{p_i^2}\rfloor+\lfloor\frac{n}{p_i^3}\rfloor+...}

说明:因为在计算 p_i^k 的倍数的贡献之前已经计算完了 p_i^j(j<k) 的贡献,故底数为 p_i 而不是 p_i^k

对于 k\ge2 ,因为直接算复杂度不高,比 O(\sqrt{n}) 略大,可以不优化,暴力计算贡献。

对于 k=1 ,注意到较大的质数中满足 \frac{n}{p_i}=\frac{n}{p_j}(i,j) 较多(说白了就是可以数论分块),可以将贡献的计算表示为 \prod\limits_{l\le r\wedge \frac{n}{p_l}=\frac{n}{p_r}} (p_l p_{l+1} ... p_r)^{\frac{n}{p_l}} ,配合快速幂,我们只需要预处理出 \le n 的质数 \mathrm{mod}\ p 的前缀积即可。

考虑到这些质数有 O(\frac{n}{\log n}) 个,暴力预处理需要 O(T\frac{n}{\log n}) 的时间,实际看大约是 5\times10^8long long 乘法+取模,因为取模开销较大,故使用 \mathrm{Barrett} 约减进行常数优化(具体内容较为玄学,可自行搜索资料)。

如何得到这么多质数?显然,我们需要一个足够优化的质数筛,尽管理论上讲 O(n) 的欧拉筛比 O\Big(n\log\big(\log(n)\big)\Big) 的埃氏筛更为优秀,但埃氏筛的潜力更多,经过 \mathrm{Wheel\ Factorization} 优化后的埃氏筛足以在 1000ms(实测洛谷 ide 在 C++14 with O2 的情况下大概筛了 600~700ms,实际以评测结果为准) 内筛完 2^{30}(约 10^9 )以内的所有质数。其大致思想如下:

正常埃氏筛有一个忽略偶数的优化,我们可以同理推出 3,5,7 等质数倍数的优化,本题代码中选取了 3,5,7,11,13 进行筛选,对 2 特殊处理。筛法具体实现上细节特别多且十分复杂,可以参考这题的题解。

我把这道黑题成功转化成了另一道黑题

小技巧:

code (码风玄学,望谅解)


#include <bits/stdc++.h>
using namespace std;

namespace prime_sieve{
    using uint=unsigned int;
    using ullong=unsigned long long;
    template<const uint size=0ull>
    struct bitset{
        ullong data[(size-1>>6)+1];
        ullong pow[64]={
            0x1ull,0x2ull,0x4ull,0x8ull,
            0x10ull,0x20ull,0x40ull,0x80ull,
            0x100ull,0x200ull,0x400ull,0x800ull,
            0x1000ull,0x2000ull,0x4000ull,0x8000ull,
            0x10000ull,0x20000ull,0x40000ull,0x80000ull,
            0x100000ull,0x200000ull,0x400000ull,0x800000ull,
            0x1000000ull,0x2000000ull,0x4000000ull,0x8000000ull,
            0x10000000ull,0x20000000ull,0x40000000ull,0x80000000ull,
            0x100000000ull,0x200000000ull,0x400000000ull,0x800000000ull,
            0x1000000000ull,0x2000000000ull,0x4000000000ull,0x8000000000ull,
            0x10000000000ull,0x20000000000ull,0x40000000000ull,0x80000000000ull,
            0x100000000000ull,0x200000000000ull,0x400000000000ull,0x800000000000ull,
            0x1000000000000ull,0x2000000000000ull,0x4000000000000ull,0x8000000000000ull,
            0x10000000000000ull,0x20000000000000ull,0x40000000000000ull,0x80000000000000ull,
            0x100000000000000ull,0x200000000000000ull,0x400000000000000ull,0x800000000000000ull,
            0x1000000000000000ull,0x2000000000000000ull,0x4000000000000000ull,0x8000000000000000ull
        };
        inline bitset(const ullong &x=0){
            reset();
            data[0]=x;
        }
        inline bool operator [](const int &pos){
            return data[pos>>6]&pow[pos&63];
        }
        inline void set(const int &pos){
            data[pos>>6]|=pow[pos&63];
        }
        inline void reset(const int &pos){
            data[pos>>6]&=~pow[pos&63];
        }
        inline void set(){
            memset(data,0xff,sizeof(data));
        }
        inline void reset(){
            memset(data,0x00,sizeof(data));
        }
    };

    const int prime_tot=54400028;
    const int max_primes=160000;
    const int sieve_span=1<<22;
    const int sieve_words=sieve_span>>7;
    const int wheel_size=3*5*7*11*13;

    bitset<sieve_words<<6> sieve;
    bitset<wheel_size<<6> pattern;
    int primes[max_primes],mcnt;
    int all_prime[prime_tot+sieve_span],pcnt;

    inline void pre_sieve(){
        for(int i=3;i<1024;i+=2){
            if(!sieve[i>>1]){
                for(int j=(i*i>>1);j<(1<<20);j+=i){
                    sieve.set(j);
                }
            }
        }
        for(int i=8;i<(1<<20);i++){
            if(!sieve[i]){
                primes[mcnt++]=i<<1|1;
            }
        }
        for(int i=1;i<wheel_size<<6;i+=3) pattern.set(i);
        for(int i=2;i<wheel_size<<6;i+=5) pattern.set(i);
        for(int i=3;i<wheel_size<<6;i+=7) pattern.set(i);
        for(int i=5;i<wheel_size<<6;i+=11) pattern.set(i);
        for(int i=6;i<wheel_size<<6;i+=13) pattern.set(i);
    }

    inline void update_sieve(int base){
        int tmp=base%wheel_size;
        tmp=(tmp+((tmp*105)&127)*wheel_size)>>7; // 105*wheel_size%128=127
        for(int i=0,k;i<sieve_words;i+=k,tmp=0){
            k=min(wheel_size-tmp,sieve_words-i);
            memcpy(sieve.data+i,pattern.data+tmp,k<<3);
        }
        if(base==0){
            sieve.data[0]|=1;
            sieve.data[0]&=~(0b1101110);
        }
        for(int i=0;i<mcnt;i++){
            long long j=primes[i]*primes[i];
            if(j>base+sieve_span-1) break;
            if(j>base) j=(j-base)>>1;
            else{
                j=primes[i]-base%primes[i];
                if(!(j&1)) j+=primes[i];
                j>>=1;
            }
            while(j<sieve_span>>1){
                sieve.set(j);
                j+=primes[i];
            }
        }
    }

    inline void segment_sieve(int base,int lim){
        update_sieve(base);
        int u=min(base+sieve_span,lim);
        for(int i=0;i<sieve_words;i++){
            ullong tmp=~sieve.data[i];
            while(tmp){
                int p=__builtin_ctzll(tmp);
                int u=base+(i<<7)+(p<<1)+1;
                if(u>=lim) break;
                all_prime[pcnt++]=u;
                tmp-=tmp&-tmp;
            }
        }
    }

    inline void fast_sieve(int lim) {
        pre_sieve();
        all_prime[pcnt++]=2;
        for(int base=0;base<lim;base+=sieve_span){
            segment_sieve(base,lim);
        }
    }

    #define prime all_prime
}
using namespace std;
using ll=long long;
using ld=long double;
using prime_sieve::prime;
const int maxn=1<<30;
const int ptot=54400028;
int T,n[5],p[5]={1,1,1,1,1}; // 不读入默认模 1 避免 RE
struct node{
    ll v[5];
    inline node(int vv=0){
        v[0]=v[1]=v[2]=v[3]=v[4]=vv;
    }
} prod[(ptot>>6)+114];

inline void prod_init(){
    __uint128_t brt[5]; // barrett 约减
    brt[0]=((__uint128_t)1<<64)/p[0];
    brt[1]=((__uint128_t)1<<64)/p[1];
    brt[2]=((__uint128_t)1<<64)/p[2];
    brt[3]=((__uint128_t)1<<64)/p[3];
    brt[4]=((__uint128_t)1<<64)/p[4];

    prod[0]=node(2);
    for(int i=64;i<ptot;i+=64){
        prod[i>>6]=prod[(i>>6)-1];
        for(int j=i-63;j<=i;j++){
            prod[i>>6].v[0]=prod[i>>6].v[0]*prime[j];
            prod[i>>6].v[1]=prod[i>>6].v[1]*prime[j];
            prod[i>>6].v[2]=prod[i>>6].v[2]*prime[j];
            prod[i>>6].v[3]=prod[i>>6].v[3]*prime[j];
            prod[i>>6].v[4]=prod[i>>6].v[4]*prime[j];

            prod[i>>6].v[0]=prod[i>>6].v[0]-p[0]*(brt[0]*prod[i>>6].v[0]>>64);
            prod[i>>6].v[1]=prod[i>>6].v[1]-p[1]*(brt[1]*prod[i>>6].v[1]>>64);
            prod[i>>6].v[2]=prod[i>>6].v[2]-p[2]*(brt[2]*prod[i>>6].v[2]>>64);
            prod[i>>6].v[3]=prod[i>>6].v[3]-p[3]*(brt[3]*prod[i>>6].v[3]>>64);
            prod[i>>6].v[4]=prod[i>>6].v[4]-p[4]*(brt[4]*prod[i>>6].v[4]>>64);

            while(prod[i>>6].v[0]>=p[0]) prod[i>>6].v[0]-=p[0];
            while(prod[i>>6].v[1]>=p[1]) prod[i>>6].v[1]-=p[1];
            while(prod[i>>6].v[2]>=p[2]) prod[i>>6].v[2]-=p[2];
            while(prod[i>>6].v[3]>=p[3]) prod[i>>6].v[3]-=p[3];
            while(prod[i>>6].v[4]>=p[4]) prod[i>>6].v[4]-=p[4];
        }
    }
}

inline int get_prod(int u,int v){
    /*
    int r=1;
    for(int i=0;i<=v;i++){
        r=1ll*r*prime[i]%p[u];
    }
    return r;
    */
    int ret=prod[v>>6].v[u];
    for(int i=1;i<=(v&63);i++){
        ret=(ll)ret*prime[i+((v>>6)<<6)]%p[u];
    }
    return ret;
}

inline ll quick_power(ll x,int y,int p){
    ll r=1;
    while(y){
        if(y&1) (r*=x)%=p;
        (x*=x)%=p,y>>=1;
    }
    return r;
}

inline ll factorial(int n,int u){
    ll ret=1;
    for(int i=0;prime[i]*prime[i]<=n;i++){
        ll j=prime[i]*prime[i];
        while(j<=n){
            (ret*=quick_power(prime[i],n/j,p[u]))%=p[u];
            j*=prime[i];
        }
    }
    int l=1,r;
    using prime_sieve::pcnt;
    while(l<=n){
        int r=n/(n/l);
        int p1=lower_bound(prime,prime+pcnt,l)-prime;
        int p2=upper_bound(prime,prime+pcnt,r)-1-prime;
        int pp1=p1?quick_power(get_prod(u,p1-1),n/l,p[u]):1;
        int pp2=~p2?quick_power(get_prod(u,p2),n/l,p[u]):1;
        (ret*=1ll*pp2*quick_power(pp1,p[u]-2,p[u])%p[u])%=p[u];
        l=r+1;
    }
    return ret;
}

inline int solve(int u){
    if(n[u]>=p[u]) return 0;
    else if(n[u]<=(p[u]-1-n[u])) return factorial(n[u],u);
    else return quick_power(factorial(p[u]-1-n[u],u),p[u]-2,p[u])*
                quick_power(p[u]-1,p[u]-n[u],p[u])%p[u]; // 威尔逊定理优化
}

signed main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr),cout.tie(nullptr);
    prime_sieve::fast_sieve(maxn);
    //for(int i=0;i<=10;i++) cout<<prime[i]<<' '; cout<<endl;
    cin>>T;
    for(int i=0;i<T;i++){
        cin>>n[i]>>p[i];
    }
    prod_init();
    for(int i=0;i<T;i++){
        cout<<solve(i)<<'\n';
    }
    return 0;
}

洛谷的评测机太慢了,在 Atcoder 不用 Barrett 约减优化也不会 TLE

提交记录(挺惊险的,感觉还可以再优化)