p5282 题解

· · 题解

这是一篇非正解的题解,如想要学习 O(\sqrt n\log n) 的解法请自行略过。

不需要什么多项式,暴力是可以通过本题的。

题意:给定 n,p,求 n! \mod p 的值。

显然直接做是 O(n) 的。

数据范围2^{31},题目却提示复杂度小于 O(\sqrt n\log n)。但是人家 O(n^2) 都能过百万凭啥我 O(n) 不能过 2^{31}

于是考虑优化 O(n)。显然我们不能通过减少乘的次数来卡常,或者说通过某些方法减少乘的次数反而会导致变慢或者导致算法变高级

阶乘是一个非常适合循环展开的东西,我们可以把阶乘每 8 个一组分别计算。即分别计算 1\times 9\times 17\times \dots \times 8k+12\times 10\times 18\times \dots \times 8k+2 \dots 后全部乘起来。

可惜循环展开还是不够快,但是有一个比循环展开更快且更直接的批量操作的东西:指令集

这里我采用了 AVX2 指令集进行优化。AVX2 可以支持高效的 256 位向量运算,也就是可以同时运算 8 个整型。

由于指令集指令过于丰富复杂,下面只有代码中带有一些对指令的讲解。这是微软提供的SSE2指令集的说明,与AVX2指令类似。

AXV2 的一个变量占用 256 位空间,根据不同情况可以当作不同精度的浮点或整型使用,下面主要当作 8 个 int 使用。

它的整型支持大部分的位运算以及加减乘运算,但是不支持除法及取模,因此我们需要自己想办法实现取模。

下面有几种解决方案:

1:通过转换类型实现除法再进行取模。但是受到硬件限制 cpu 的除法运算极其慢,无法采用。

2:通过 barrett 约减实现取模。但是在本题数据范围内使用 barrett 约减将会涉及 AVX2 不支持的 128 位整数的运算,无法采用。

3:通过蒙哥马利约减实现取模。蒙哥马利约减可以仅在 64 位整数范围内完成取模且速度较快,因此我采用了蒙哥马利约减实现本题。

注:

barrett 约减:

对于 r=z\mod p 基于低成本的除法运算获得 q=[2^k/p],每次取模时先计算 x=q\times z/2^k\approx z/p,这样我们就通过一次乘法运算及一次位移运算高效的完成了一次不是很精确的除法。

接下来计算 r'=r-x\times p,在除法精确的情况下 r' 即为答案,但是由于我们完成的除法不精确,我们需要再进行一次校验,若 r' 大于 pr'=r'-p,在 k 足够大时 r' 即为所求。于是我们就通过两次乘法替换掉了一次取模,实际上编译器对于常量的取模优化就是基于此的。

但是由于 32 位整数作为模数时 k 通常需要取 60 以上,计算 x=q\times z/2^k\approx z/p 时需要用到 128 位整数,因此无法使用 barrett 约减优化此题。

蒙哥马利约减:

由于蒙哥马利约减证明及流程过于复杂,这里仅放一个百度文库链接。(后续可能会填坑?)

简单来说,蒙哥马利约减就是通过将整数转换到蒙哥马利数域使得取模可以使用两次乘法及一次位移实现,且对于 32 位的模数可以在 64 位内完成运算。

预处理进入了蒙哥马利数域的 1-8 并放入 AVX2 变量内,假设叫 ml ,同时预处理一个塞了 8 个 8 的 AVX2 变量,假设叫 ad ,答案预设为塞了 8 蒙哥马利数域的 1 的 AVX2 变量。

这样我们只需要每次将答案与 ml 相乘,再将 ml 加上 ad 即可。所有运算均在数域内完成,没有进出数域的时间浪费。

于是我们所需要的所有运算都可以实现了,终于可以上代码了。

#include<stdio.h>
#include<immintrin.h>
#pragma GCC target("avx2")
//__m256i为AVX2指令集的整型变量 
static unsigned mod,r,n2_;
static __m256i a0,mod1,R,hi32,ans,ml,ad;
long long exgcd(long long a,long long b,long long &x,long long &y)
{
    long long d=a;
    if(b==0) x=1,y=0;
    else d=exgcd(b,a%b,y,x),y-=a/b*x;
    return d;
}
inline unsigned mul(unsigned x,unsigned y)//蒙哥马利数域内的乘法及取模 
{
    unsigned long long z=(unsigned long long)x*y;//计算乘积 
    return (z+(unsigned long long)(unsigned(z)*r)*mod)>>32;//对p取模 
}
inline __m256i add(__m256i _num1,__m256i _num2)//将8个32位整数相加并取模 
{
    __m256i apb=_mm256_add_epi32(_num1,_num2),//将num1,num2内的8个整数对应相加 
    ret=_mm256_sub_epi32(apb,mod1);//将答案减去mod 
    __m256i cmp=_mm256_cmpgt_epi32(a0,ret),//得到答案内小于0的数 
    add=_mm256_and_si256(cmp,mod1);
    return _mm256_add_epi32(add,ret);//将小于0的数加mod 
}
inline __m256i mul(__m256i _num1,__m256i _num2)//将8个32位整数相乘并取模 
{
    __m256i _num3=_num1,_num4,_num5=_num2;
    _num2=_mm256_mul_epu32(_num1,_num2);//将偶数位的整数相乘(得到4个64位整数并放回) 
    _num1=_mm256_mul_epu32(_mm256_mul_epu32(_num2,R),mod1);
    _num4=_mm256_srli_epi64(_mm256_add_epi64(_num1,_num2),32);//每个64位整数均左移32位,空出高32位(奇数位) 
    _num1=_mm256_srli_si256(_num3,4);_num2=_mm256_srli_si256(_num5,4);//进行4字节位移,即将奇数位的整数移动到偶数位 
    _num2=_mm256_mul_epu32(_num1,_num2);//将原奇数位的整数相乘 
    _num1=_mm256_mul_epu32(_mm256_mul_epu32(_num2,R),mod1);
    _num1=_mm256_and_si256(_mm256_add_epi64(_num1,_num2),hi32);//通过跟hi32做与取出低32位(偶数位) 
    return _mm256_or_si256(_num1,_num4);//将偶数位乘积和奇数位乘积合并位一个AVX2变量返回 
}
inline unsigned mon_in(unsigned x){return mul(x,n2_);}//进入蒙哥马利数域 
inline unsigned mon_out(unsigned x)//离开蒙哥马利数域 
{unsigned ret=((x+(unsigned long long)(unsigned(x)*r)*mod)>>32);return ret<mod?ret:ret-mod;}
void solve(int n,int p)
{
    unsigned i=1;
    long long x,y;
    mod=p;
    n2_=-(unsigned long long)mod%mod;exgcd(mod,1ll<<32,x,y);r=-unsigned(x);//初始化一些蒙哥马利约减需要使用的变量
    a0=_mm256_setzero_si256(),//预处理全0的AVX2变量(做加法有用) 
    mod1=_mm256_set1_epi32(mod),R=_mm256_set1_epi32(r);//预处理全mod1,全r的AVX2变量 
    hi32=_mm256_set_epi32(-1,0,-1,0,-1,0,-1,0);//预处理奇数位全1的AVX2变量(做乘法有用) 
    ans=_mm256_set1_epi32(mon_in(1));//初始化ans 
    ad=_mm256_set1_epi32(mon_in(8));//初始化全8的AVX2变量 
    ml=_mm256_set_epi32(mon_in(8),mon_in(7),mon_in(6),mon_in(5),mon_in(4),mon_in(3),mon_in(2),mon_in(1));//初始化因数 
    for(;i+8<=n;i+=8)
    {
        ans=mul(ans,ml);//将ans的8位与因数的8位依次相乘 
        ml=add(ml,ad);//将因数8位全部加8 
    }
    unsigned *fl=(unsigned*)&ans;
    unsigned as=mul(mul(mul(fl[0],fl[1]),mul(fl[2],fl[3])),mul(mul(fl[4],fl[5]),mul(fl[6],fl[7])));//取出ans的8位并相乘 
    as=mon_out(as);//离开蒙哥马利数域 
    for(;i<=n;i++)//处理整段未做掉的末几位 
        as=1ull*as*i%mod;
    printf("%u\n",as);
}
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        int n,p;
        scanf("%d%d",&n,&p);
        solve(n,p);
    }
    return 0;
}

然后就愉快的切掉了一道黑题提交记录。其实时间卡的还是挺死的