题解:P5282 【模板】快速阶乘算法

· · 题解

0xFF 写在前面

和 bh1234666 一样,这篇题解也不是一篇正常的题解,本题解将教你怎样在 O(n) 的基础上进行一步一步进行常数优化通过这道题。

但与 bh1234666 的题解不同的是:

让我们开始吧~~

注意:

  1. 以下代码中省略了 T,并增加了 clock_t 进行时间测试。
  2. 为了不浪费评测资源,以下测试如果时间超过 1.5s,则取本地结果。
  3. 为了防止代码冗长导致可读性下降,循环展开将放在最后使用。
  4. 如果没有给出特殊说明,则以下面这组数据测试:
    Input
    1000000000 2147483647
    Output
    1289569604

0x00 朴素暴力

我们很容易就能写出一个 O(n) 的暴力。

#include <stdio.h>
#include <time.h>
int main() {
    int n, mod, ans = 1;
    scanf("%d%d", &n, &mod);
    clock_t st = clock();

    for (int i = 1; i <= n; ++ i)
        ans = 1ull * ans * i % mod;
    printf("%u\n", ans);

    clock_t ed = clock();
    printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}

以下是运行结果(本地):

1289569604
time: 9608.000000 ms

此时可能不少人会决定使用循环展开,但实际上效果不大,因为取模运算(对变量取模)的端口需求很多,所以在计算模数的时候几乎会停止所有其它操作,导致几乎不能使 CPU 并行。所以我们需要将取模操作变为其它操作。

0x01 避免取模操作的方法

常见的方法有:

方法 modint 范围内的特殊要求
Barrett 约减 需要使用 uint128_t
Montgomery 约减 模数是奇数
梅森素数特殊取模 模数是梅森素数
查表法 模数较小

Barrett 约减显然可行。

由于 2 以外的质数肯定是奇数(当然本题测试点中没有 2),所以 Montgomery 约减也可行。

至于后两者,在本题根本不可能用到吧。

Barrett 约减

原理:Here.

该算法能够先得出 x \bmod p 的近似值(与正确的值之多相差 1),然后再用 if/while 进行判断。

以下是代码实现:

#include <cstdio>
#include <time.h>
#include <stdint.h>
int main() {
    uint32_t n, mod; uint64_t ans = 1;
    scanf("%d%d", &n, &mod);
    clock_t st = clock();

    typedef unsigned __int128 uint128_t;
    const uint128_t barret = (uint128_t(1) << 64) / mod;
    for (uint32_t i = 1; i <= n; ++ i) {
        ans *= i, ans -= mod * (barret * ans >> 64);
        while (ans >= mod) ans -= mod;
    }
    printf("%llu\n", ans % mod);

    clock_t ed = clock();
    printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}

以下是运行结果(本地):

1289569604
time: 3270.000000 ms

实际上把第 13 行的 while 改成 if 不会改变程序运行结果,但运行时间会比原来多 500 ms。原因是 Barrett 约减的精度很高,几乎不会使条件成立,所以 CPU 对 if 语句做的所有分支预测几乎全部无效,导致几乎每次都要清空流水线。

Montgomery 约减

原理:Here.(找了很多篇好像讲得都不是很好)

下面是代码:

#include <cstdint>
#include <cstdio>
#include <ctime>

using u32 = uint32_t;
using i32 = int32_t;
using u64 = uint64_t;
using i64 = int64_t;

static u32 m, inv, r2;

// 使用扩展欧几里得算法计算模逆元
u32 getinv() {
    i64 t0 = 0, t1 = 1;
    u64 r0 = u64(1) << 32; // R = 2^32
    i64 r1 = m;
    while (r1 != 0) {
        u64 q = r0 / r1;
        i64 t2 = t0 - q * t1; t0 = t1; t1 = t2;
        u64 r2 = r0 - q * r1; r0 = r1; r1 = r2;
    }
    if (r0 != 1) return 0; // 没有逆元
    if (t0 < 0) t0 += (u64(1) << 32); // 确保结果为正
    return t0;
}

struct Mont {
private:
    u32 x;

public:
    // 蒙哥马利约减
    static u32 reduce(u64 x) {
        u32 u = u32(x) * inv; // 计算 u = x * m' mod R
        u64 v = x + u64(u) * m; // 计算 x + u * m
        u32 res = v >> 32; // 除以 R
        return res >= m ? res - m : res;
    }

    Mont() : x(0) {}
    Mont(i32 x) : x(reduce(u64(x) * r2)) {} // 转换到蒙哥马利域

    Mont& operator+=(const Mont& rhs) { x += rhs.x; if (x >= m) x -= m; return *this; }
    Mont& operator-=(const Mont& rhs) { if (x < rhs.x) x += m; x -= rhs.x; return *this; }
    Mont& operator*=(const Mont& rhs) { x = reduce(u64(x) * rhs.x); return *this; }
    friend Mont operator+(Mont x, const Mont& y) { return x += y; }
    friend Mont operator-(Mont x, const Mont& y) { return x -= y; }
    friend Mont operator*(Mont x, const Mont& y) { return x *= y; }
    i32 get() { return reduce(x); } // 转换回普通域
};

void Init(u32 modulus) {
    m = modulus;
    inv = -getinv(); // 计算 m' = -m^{-1} mod R
    r2 = (-u64(m)) % m; // 计算 R^2 mod m
}

int main() {
    int n, mod;
    scanf("%d%d", &n, &mod);
    Init(mod); Mont ans = 1;
    clock_t st = clock();

    for (int i = 1; i <= n; ++ i)
        ans *= i;
    printf("%u\n", ans.get());

    clock_t ed = clock();
    printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}

以下是运行结果(本地):

1289569604
time: 3580.000000 ms

可以发现,Montgomery 约减在一般情况下效率没有 Barrett 约减高,而且代码也更长。但我们能发现 Montgomery 约减的优点:

总结

由于 uint128_t 是由两个 uint64_t 拼接起来的,也就是说不能使用一个寄存器存下一个 uint128_t 变量,所以接下来我们的优化将基于 Montgomery 约减而不是 Barrett 约减。

0x02 阶乘的性质——常数还能减小

现在我们要求出 n!

假设我们已经求出了 \lfloor \frac n 2 \rfloor !,那么 n 以内偶数的乘积就是 \lfloor \frac n 2 \rfloor !\times 2^{\lfloor \frac n 2 \rfloor}。而且我们在计算 \lfloor \frac n 2 \rfloor ! 的时候已经算出了 1\sim \lfloor \frac n 2 \rfloor 中奇数的乘积,所以我们只需再求 \lfloor \frac n 2 \rfloor \sim n 中的奇数乘积就好了。

这样的方法能使我们节省大约 \frac 1 4 的常数。

实际上我们还可以在求 \lfloor \frac n 2 \rfloor ! 的时候也使用上面的方法进行优化,又可以节省 \frac 1 2 的常数。

于是我们就可以先将 n 缩小为 2^x(x\in \N),然后用朴素暴力处理剩余部分。当我们像这样进行了 k 轮操作时,理论上就能使常数变为原来的 \frac 1{2^{t + 1}}。但实际上我们并不能做太多次这样的操作,这样会导致 2^x 非常小,而剩余部分很多,反而会让常数增加。所以我们只要做 8 次就可以了。

0x03 循环展开——万恶的开始

对于卡常,大多数人第一个想到的是读入优化(然而在这道题一点用都没有),第二个是循环展开(在这道题里很有用)。

所谓循环展开,就是一种牺牲程序的尺寸来加快程序的执行速度的优化方法,可以由程序员完成,也可由编译器自动优化完成(但 O2 优化是不彻底的)。它通过将循环体内的代码复制多次的操作,进而减少循环分支指令执行的次数,增大处理器指令调度的空间,获得更多的指令级并行。

循环展开一般是这样的:

int sum = 0;
for (int i = 1; i <= n; ++ i) sum += a[i];
=>
int sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0, i;
for (i = 1; i + 3 <= n; i += 4)
  sum1 += a[i], sum2 += a[i + 1], sum3 += a[i + 2], sum4 += a[i + 3];
for (; i <= n; ++ i) sum1 += a[i];
int sum = sum1 + sum2 + sum3 + sum4;

所以我们很容易就能让 DEEPSEEK 改写出一段展开了 64 遍的代码(框架借鉴了 bh1234666,但没有使用指令集)(由于 Mont 类不利于我们进行优化,所以我去掉了):

由于代码太长,请转移至 https://www.luogu.me/paste/qp1t5863

以下是运行结果(本地):

1289569604
time: 756.000000 ms

以下是运行结果(洛谷):

1289569604
time: 767.867000 ms

可以看见,我们已经能在 1000 ms 解决一个问题了。但我们的规模只有 10^9,所以当 T=5,n=2\times 10^9 时,我们的程序还是要跑 7000 ms 左右,寄!

提交上去发现有 30 分

但我们稍微想一想就知道,其实我们展开 64 次操作如果不加上任何修饰相当于我们的信息是一维的: op1, op2, op3, op4, ..., op64

很显然计算机会发现自己无法同时并行这么多条运算,【所以可能会做 8 遍 8 个操作并行】(注:【】内的内容是为了方便理解,实际上计算机的操作不会这么简单),这也就说明,我们展开 64 次和 8 次是一样的……

吗?

0x04 别忘记结构体——结束了?

我们发现洛谷的评测机上面能够使用 #pragma GCC target("avx512f")

这说明了洛谷的评测机在理论上能够一次性计算 864 位整数的加、减、乘、与、或、非、maxmin>< 等操作。

所以我们可以把我们循环展开的 64 个数进行 8\times 8 分组,让信息变成二维?

(u32x8){op1, op2, op3, op4, op5, op6, op7, op8},
(u32x8){op9, op10, op1, op12, op13, op14, op15, op16},
...
(u32x8){op57, op58, op59, op60, op61, op62, op63, op64},

像这样,我们就可以诱导编译器帮助我们使用指令集。

代码:

#include <cstdint>
#include <cstdio>
#include <ctime>

using u32 = unsigned int;
using u64 = unsigned long long;
using i64 = long long;

static u32 mod, r, n2_;

// 扩展欧几里得算法
i64 exgcd(i64 a, i64 b, i64 &x, i64 &y) {
    i64 d = a;
    if (b == 0) x = 1, y = 0;
    else d = exgcd(b, a % b, y, x), y -= a / b * x;
    return d;
}
inline u32 mul(u32 x, u32 y) {
    unsigned ret = (1ull * x * y + 1ull * (u32(1ull * x * y) * r) * mod) >> 32;
    return ret < mod ? ret : ret - mod;
}
inline u32 add(u32 a, u32 b) { u32 res = a + b; return res >= mod ? res - mod : res; } // 蒙哥马利加法

// 使用结构体代替数组,帮助编译器优化
struct u32x8 {
    u32 v[8];

    u32x8() = default;
    u32x8(u32 val) { for (int i = 0; i < 8; i++) v[i] = val; }
    u32x8 operator+(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = add(v[i], other.v[i]); return result; }
    u32x8 operator*(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = mul(v[i], other.v[i]); return result; }
    u32x8& operator+=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = add(v[i], other.v[i]); return *this; }
    u32x8& operator*=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = mul(v[i], other.v[i]); return *this; }
};

inline u32 mon_in(u32 x) { return mul(x, n2_); }

inline u32 mon_out(u32 x) { u32 ret = ((x + (u64)((u32)x * r) * mod) >> 32); return ret < mod ? ret : ret - mod; }

inline u32 qpow(u32 n, u32 m, u32 p) {
    if (!m) return 1;
    u32 ret = qpow(n, m >> 1, p);
    ret = (u64)ret * ret % p;
    if (m & 1) return (u64)ret * n % p;
    else return ret;
}

void solve(int N, u32 p) {
    const int mv = 8;
    int n = N - (N & (256 * (1 << mv) - 1));

    mod = p;
    n2_ = -(u64)mod % mod;

    i64 x, y;
    exgcd(mod, 1ll << 32, x, y);
    r = -u32(x);

    u32 as = mon_in(1);
    u32 as2 = mon_in(1);

    // 使用结构体数组
    u32x8 ans[8];
    u32x8 ml[8];
    u32x8 ad_val(mon_in(64));

    for (int i = 0; i < 8; i++) ans[i] = u32x8(mon_in(1)); // 初始化ans
    for (int i = 0; i < 8; i++) // 初始化ml
        for (int j = 0; j < 8; j++)
            ml[i].v[j] = mon_in(i * 8 + j + 1);

    for (unsigned i = 1; i + 63 <= (n >> mv); i += 64) // 主循环 - 编译器应该能够自动向量化这个循环
        for (int j = 0; j < 8; j++) {
            ans[j] *= ml[j];
            ml[j] += ad_val;
        }

    for (int i = 0; i < 8; i++) { // 分离奇偶项
        u32 odd_prod = mon_in(1);
        u32 even_prod = mon_in(1);

        for (int j = 0; j < 8; j++)
            if (j & 1) odd_prod = mul(odd_prod, ans[i].v[j]);
            else even_prod = mul(even_prod, ans[i].v[j]);

        as = mul(as, odd_prod);
        as2 = mul(as2, even_prod);
    }

    for (int j = mv - 1; j >= 0; j--) { // 多轮优化
        as = mul(as, as2);
        as = mul(as, mon_in(qpow(2, n >> (j + 1), p)));

        u32x8 inner_ans[8];
        u32x8 inner_ml[8];
        u32x8 inner_ad_val(mon_in(128));

        const unsigned add_ = n >> (j + 1);
        for (int i = 0; i < 8; i++) inner_ans[i] = u32x8(mon_in(1)); // 初始化inner_ans
        for (int i = 0; i < 8; i++) // 初始化inner_ml
            for (int k = 0; k < 8; k++)
                inner_ml[i].v[k] = mon_in(add_ + i * 16 + k * 2 + 1);
        for (unsigned i = add_; i + 127 <= (n >> j); i += 128) // 处理当前段
            for (int k = 0; k < 8; k++) {
                inner_ans[k] *= inner_ml[k];
                inner_ml[k] += inner_ad_val;
            }
        for (int i = 0; i < 8; i++) { // 合并当前处理段后半奇数的乘积
            u32 prod0 = mul(mul(inner_ans[i].v[0], inner_ans[i].v[1]), mul(inner_ans[i].v[2], inner_ans[i].v[3]));
            u32 prod1 = mul(mul(inner_ans[i].v[4], inner_ans[i].v[5]), mul(inner_ans[i].v[6], inner_ans[i].v[7]));
            as2 = mul(as2, mul(prod0, prod1));
        }
    }
    // 将奇数偶数部分相乘得到答案
    as = mul(as, as2);
    as = mon_out(as);
    for (int i = n + 1; i <= N; i++) as = (u64)as * i % p; // 暴力将最后一段乘上去
    printf("%u\n", as);
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        int n, p;
        scanf("%d%d", &n, &p);
        solve(n, p);
    }
    return 0;
}

然后我们就 愉快地(其实是比调正解还痛苦地)不使用指令集,且用蓝题的知识(exgcd 是蓝) AC 掉了一道黑题!

Record.

可见没有比使用指令集的代码更快(用指令集能够再减少一半常数)

PS:不要用 GCC 9 提交,否则你会获得 10 分的高分(后续会修复这个问题)

0x05 还能对偶数取模——竟然没有结束

代码:自己写(有需要可以私信作者)

0x06 使用 avx2/avx512f 指令集

由于 bh1234666 已经写过了 avx2 的版本,那我就写一下 avx512f 的版本:

__m512i 变量能够同时存储/计算 8uint64_t16uint32_t32uint16_t64uint8_t

所以我们可以把本题的代码写成这样:

inline __m512i add(__m512i a, __m512i b) {
    __m512i sum = _mm512_add_epi32(a, b); // 8 个 32 位整数计算加法
    __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(sum, mod1); // 这里有个容易错的地方,__m512 的 cmp 返回的是 16 位掩码,和 __m256i 的 cmp 返回值不同
    return _mm512_mask_sub_epi32(sum, overflow_mask, sum, mod1); // 这里不能忘记判断有没有超过 mod,但好像 bh1234666 忘记判断了……
}
inline __m512i mul512(__m512i _num1, __m512i _num2) {
    __m512i _num3 = _num1, _num4, _num5 = _num2;
    _num2 = _mm512_mul_epu32(_num1, _num2); // 8 个 32 位整数计算乘法
    _num1 = _mm512_mul_epu32(_mm512_mul_epu32(_num2, R), mod1);
    _num4 = _mm512_srli_epi64(_mm512_add_epi64(_num1, _num2), 32);
    _num1 = _mm512_bsrli_epi128(_num3, 4);  // 每个 128 位通道右移 4 字节
    _num2 = _mm512_bsrli_epi128(_num5, 4);
    _num2 = _mm512_mul_epu32(_num1, _num2);
    _num1 = _mm512_mul_epu32(_mm512_mul_epu32(_num2, R), mod1);
    _num1 = _mm512_and_si512(_mm512_add_epi64(_num1, _num2), hi32);
    __m512i result = _mm512_or_si512(_num1, _num4);
    __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(result, mod1);
    return _mm512_mask_sub_epi32(result, overflow_mask, result, mod1); // 这里也不能忘记判断有没有超过 mod
}

这里 是全部代码(avx512f 优化)

0x07 写在最后