题解:P5282 【模板】快速阶乘算法
_Kagamine_Rin_ · · 题解
0xFF 写在前面
和 bh1234666 一样,这篇题解也不是一篇正常的题解,本题解将教你怎样在
但与 bh1234666 的题解不同的是:
- 该题解可能更容易看懂
- 该题解会提供完整的代码(但请不要直接抄袭)
- 该题解会提供基于蒙哥马利约减(Montgomery Reduction)的
modint模板 - 该题解的代码没有使用任何指令集和
#pragma(这意味着你能在正式比赛上使用) - 由于是较为极限卡常的
O(n) 做法,所以可能会包含一些较为玄学的内容(比如最后一步把8\times 8 的数组改为8 个含有8 个unsigned int的结构体就能 快至少三倍)。
让我们开始吧~~
注意:
- 以下代码中省略了
T ,并增加了clock_t进行时间测试。 - 为了不浪费评测资源,以下测试如果时间超过 1.5s,则取本地结果。
- 为了防止代码冗长导致可读性下降,循环展开将放在最后使用。
- 如果没有给出特殊说明,则以下面这组数据测试:
Input
1000000000 2147483647Output
1289569604
0x00 朴素暴力
我们很容易就能写出一个
#include <stdio.h>
#include <time.h>
int main() {
int n, mod, ans = 1;
scanf("%d%d", &n, &mod);
clock_t st = clock();
for (int i = 1; i <= n; ++ i)
ans = 1ull * ans * i % mod;
printf("%u\n", ans);
clock_t ed = clock();
printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}
以下是运行结果(本地):
1289569604
time: 9608.000000 ms
此时可能不少人会决定使用循环展开,但实际上效果不大,因为取模运算(对变量取模)的端口需求很多,所以在计算模数的时候几乎会停止所有其它操作,导致几乎不能使 CPU 并行。所以我们需要将取模操作变为其它操作。
0x01 避免取模操作的方法
常见的方法有:
| 方法 | 当 int 范围内的特殊要求 |
|---|---|
| Barrett 约减 | 需要使用 uint128_t |
| Montgomery 约减 | 模数是奇数 |
| 梅森素数特殊取模 | 模数是梅森素数 |
| 查表法 | 模数较小 |
Barrett 约减显然可行。
由于 2 以外的质数肯定是奇数(当然本题测试点中没有 2),所以 Montgomery 约减也可行。
至于后两者,在本题根本不可能用到吧。
Barrett 约减
原理:Here.
该算法能够先得出 if/while 进行判断。
以下是代码实现:
#include <cstdio>
#include <time.h>
#include <stdint.h>
int main() {
uint32_t n, mod; uint64_t ans = 1;
scanf("%d%d", &n, &mod);
clock_t st = clock();
typedef unsigned __int128 uint128_t;
const uint128_t barret = (uint128_t(1) << 64) / mod;
for (uint32_t i = 1; i <= n; ++ i) {
ans *= i, ans -= mod * (barret * ans >> 64);
while (ans >= mod) ans -= mod;
}
printf("%llu\n", ans % mod);
clock_t ed = clock();
printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}
以下是运行结果(本地):
1289569604
time: 3270.000000 ms
实际上把第 13 行的 while 改成 if 不会改变程序运行结果,但运行时间会比原来多 500 ms。原因是 Barrett 约减的精度很高,几乎不会使条件成立,所以 CPU 对 if 语句做的所有分支预测几乎全部无效,导致几乎每次都要清空流水线。
Montgomery 约减
原理:Here.(找了很多篇好像讲得都不是很好)
下面是代码:
#include <cstdint>
#include <cstdio>
#include <ctime>
using u32 = uint32_t;
using i32 = int32_t;
using u64 = uint64_t;
using i64 = int64_t;
static u32 m, inv, r2;
// 使用扩展欧几里得算法计算模逆元
u32 getinv() {
i64 t0 = 0, t1 = 1;
u64 r0 = u64(1) << 32; // R = 2^32
i64 r1 = m;
while (r1 != 0) {
u64 q = r0 / r1;
i64 t2 = t0 - q * t1; t0 = t1; t1 = t2;
u64 r2 = r0 - q * r1; r0 = r1; r1 = r2;
}
if (r0 != 1) return 0; // 没有逆元
if (t0 < 0) t0 += (u64(1) << 32); // 确保结果为正
return t0;
}
struct Mont {
private:
u32 x;
public:
// 蒙哥马利约减
static u32 reduce(u64 x) {
u32 u = u32(x) * inv; // 计算 u = x * m' mod R
u64 v = x + u64(u) * m; // 计算 x + u * m
u32 res = v >> 32; // 除以 R
return res >= m ? res - m : res;
}
Mont() : x(0) {}
Mont(i32 x) : x(reduce(u64(x) * r2)) {} // 转换到蒙哥马利域
Mont& operator+=(const Mont& rhs) { x += rhs.x; if (x >= m) x -= m; return *this; }
Mont& operator-=(const Mont& rhs) { if (x < rhs.x) x += m; x -= rhs.x; return *this; }
Mont& operator*=(const Mont& rhs) { x = reduce(u64(x) * rhs.x); return *this; }
friend Mont operator+(Mont x, const Mont& y) { return x += y; }
friend Mont operator-(Mont x, const Mont& y) { return x -= y; }
friend Mont operator*(Mont x, const Mont& y) { return x *= y; }
i32 get() { return reduce(x); } // 转换回普通域
};
void Init(u32 modulus) {
m = modulus;
inv = -getinv(); // 计算 m' = -m^{-1} mod R
r2 = (-u64(m)) % m; // 计算 R^2 mod m
}
int main() {
int n, mod;
scanf("%d%d", &n, &mod);
Init(mod); Mont ans = 1;
clock_t st = clock();
for (int i = 1; i <= n; ++ i)
ans *= i;
printf("%u\n", ans.get());
clock_t ed = clock();
printf("time: %lf ms", (ed - st) * 1000.0 / CLOCKS_PER_SEC);
}
以下是运行结果(本地):
1289569604
time: 3580.000000 ms
可以发现,Montgomery 约减在一般情况下效率没有 Barrett 约减高,而且代码也更长。但我们能发现 Montgomery 约减的优点:
- 仅需使用 64 位整数(最关键的地方)
- 更适合作为
modint类
总结
由于 uint128_t 是由两个 uint64_t 拼接起来的,也就是说不能使用一个寄存器存下一个 uint128_t 变量,所以接下来我们的优化将基于 Montgomery 约减而不是 Barrett 约减。
0x02 阶乘的性质——常数还能减小
现在我们要求出
假设我们已经求出了
这样的方法能使我们节省大约
实际上我们还可以在求
于是我们就可以先将
0x03 循环展开——万恶的开始
对于卡常,大多数人第一个想到的是读入优化(然而在这道题一点用都没有),第二个是循环展开(在这道题里很有用)。
所谓循环展开,就是一种牺牲程序的尺寸来加快程序的执行速度的优化方法,可以由程序员完成,也可由编译器自动优化完成(但 O2 优化是不彻底的)。它通过将循环体内的代码复制多次的操作,进而减少循环分支指令执行的次数,增大处理器指令调度的空间,获得更多的指令级并行。
循环展开一般是这样的:
int sum = 0;
for (int i = 1; i <= n; ++ i) sum += a[i];
=>
int sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0, i;
for (i = 1; i + 3 <= n; i += 4)
sum1 += a[i], sum2 += a[i + 1], sum3 += a[i + 2], sum4 += a[i + 3];
for (; i <= n; ++ i) sum1 += a[i];
int sum = sum1 + sum2 + sum3 + sum4;
所以我们很容易就能让 DEEPSEEK 改写出一段展开了 Mont 类不利于我们进行优化,所以我去掉了):
由于代码太长,请转移至 https://www.luogu.me/paste/qp1t5863
以下是运行结果(本地):
1289569604
time: 756.000000 ms
以下是运行结果(洛谷):
1289569604
time: 767.867000 ms
可以看见,我们已经能在 1000 ms 解决一个问题了。但我们的规模只有
提交上去发现有 30 分
但我们稍微想一想就知道,其实我们展开 64 次操作如果不加上任何修饰相当于我们的信息是一维的:
op1, op2, op3, op4, ..., op64。
很显然计算机会发现自己无法同时并行这么多条运算,【所以可能会做 8 遍 8 个操作并行】(注:【】内的内容是为了方便理解,实际上计算机的操作不会这么简单),这也就说明,我们展开 64 次和 8 次是一样的……
吗?
0x04 别忘记结构体——结束了?
我们发现洛谷的评测机上面能够使用 #pragma GCC target("avx512f")。
这说明了洛谷的评测机在理论上能够一次性计算 max、min、>、< 等操作。
所以我们可以把我们循环展开的
(u32x8){op1, op2, op3, op4, op5, op6, op7, op8},
(u32x8){op9, op10, op1, op12, op13, op14, op15, op16},
...
(u32x8){op57, op58, op59, op60, op61, op62, op63, op64},
像这样,我们就可以诱导编译器帮助我们使用指令集。
代码:
#include <cstdint>
#include <cstdio>
#include <ctime>
using u32 = unsigned int;
using u64 = unsigned long long;
using i64 = long long;
static u32 mod, r, n2_;
// 扩展欧几里得算法
i64 exgcd(i64 a, i64 b, i64 &x, i64 &y) {
i64 d = a;
if (b == 0) x = 1, y = 0;
else d = exgcd(b, a % b, y, x), y -= a / b * x;
return d;
}
inline u32 mul(u32 x, u32 y) {
unsigned ret = (1ull * x * y + 1ull * (u32(1ull * x * y) * r) * mod) >> 32;
return ret < mod ? ret : ret - mod;
}
inline u32 add(u32 a, u32 b) { u32 res = a + b; return res >= mod ? res - mod : res; } // 蒙哥马利加法
// 使用结构体代替数组,帮助编译器优化
struct u32x8 {
u32 v[8];
u32x8() = default;
u32x8(u32 val) { for (int i = 0; i < 8; i++) v[i] = val; }
u32x8 operator+(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = add(v[i], other.v[i]); return result; }
u32x8 operator*(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = mul(v[i], other.v[i]); return result; }
u32x8& operator+=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = add(v[i], other.v[i]); return *this; }
u32x8& operator*=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = mul(v[i], other.v[i]); return *this; }
};
inline u32 mon_in(u32 x) { return mul(x, n2_); }
inline u32 mon_out(u32 x) { u32 ret = ((x + (u64)((u32)x * r) * mod) >> 32); return ret < mod ? ret : ret - mod; }
inline u32 qpow(u32 n, u32 m, u32 p) {
if (!m) return 1;
u32 ret = qpow(n, m >> 1, p);
ret = (u64)ret * ret % p;
if (m & 1) return (u64)ret * n % p;
else return ret;
}
void solve(int N, u32 p) {
const int mv = 8;
int n = N - (N & (256 * (1 << mv) - 1));
mod = p;
n2_ = -(u64)mod % mod;
i64 x, y;
exgcd(mod, 1ll << 32, x, y);
r = -u32(x);
u32 as = mon_in(1);
u32 as2 = mon_in(1);
// 使用结构体数组
u32x8 ans[8];
u32x8 ml[8];
u32x8 ad_val(mon_in(64));
for (int i = 0; i < 8; i++) ans[i] = u32x8(mon_in(1)); // 初始化ans
for (int i = 0; i < 8; i++) // 初始化ml
for (int j = 0; j < 8; j++)
ml[i].v[j] = mon_in(i * 8 + j + 1);
for (unsigned i = 1; i + 63 <= (n >> mv); i += 64) // 主循环 - 编译器应该能够自动向量化这个循环
for (int j = 0; j < 8; j++) {
ans[j] *= ml[j];
ml[j] += ad_val;
}
for (int i = 0; i < 8; i++) { // 分离奇偶项
u32 odd_prod = mon_in(1);
u32 even_prod = mon_in(1);
for (int j = 0; j < 8; j++)
if (j & 1) odd_prod = mul(odd_prod, ans[i].v[j]);
else even_prod = mul(even_prod, ans[i].v[j]);
as = mul(as, odd_prod);
as2 = mul(as2, even_prod);
}
for (int j = mv - 1; j >= 0; j--) { // 多轮优化
as = mul(as, as2);
as = mul(as, mon_in(qpow(2, n >> (j + 1), p)));
u32x8 inner_ans[8];
u32x8 inner_ml[8];
u32x8 inner_ad_val(mon_in(128));
const unsigned add_ = n >> (j + 1);
for (int i = 0; i < 8; i++) inner_ans[i] = u32x8(mon_in(1)); // 初始化inner_ans
for (int i = 0; i < 8; i++) // 初始化inner_ml
for (int k = 0; k < 8; k++)
inner_ml[i].v[k] = mon_in(add_ + i * 16 + k * 2 + 1);
for (unsigned i = add_; i + 127 <= (n >> j); i += 128) // 处理当前段
for (int k = 0; k < 8; k++) {
inner_ans[k] *= inner_ml[k];
inner_ml[k] += inner_ad_val;
}
for (int i = 0; i < 8; i++) { // 合并当前处理段后半奇数的乘积
u32 prod0 = mul(mul(inner_ans[i].v[0], inner_ans[i].v[1]), mul(inner_ans[i].v[2], inner_ans[i].v[3]));
u32 prod1 = mul(mul(inner_ans[i].v[4], inner_ans[i].v[5]), mul(inner_ans[i].v[6], inner_ans[i].v[7]));
as2 = mul(as2, mul(prod0, prod1));
}
}
// 将奇数偶数部分相乘得到答案
as = mul(as, as2);
as = mon_out(as);
for (int i = n + 1; i <= N; i++) as = (u64)as * i % p; // 暴力将最后一段乘上去
printf("%u\n", as);
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
int n, p;
scanf("%d%d", &n, &p);
solve(n, p);
}
return 0;
}
然后我们就 愉快地(其实是比调正解还痛苦地)不使用指令集,且用蓝题的知识(exgcd 是蓝) AC 掉了一道黑题!
Record.
可见没有比使用指令集的代码更快(用指令集能够再减少一半常数)
PS:不要用 GCC 9 提交,否则你会获得 10 分的高分(后续会修复这个问题)
0x05 还能对偶数取模——竟然没有结束
- 令模数
p = 2^x\times p' 。 - 计算
n!\bmod p' - 计算
n! 中因子2 的指数(n!=u\times 2^v 中的v ) - 计算
n! \bmod 2^k - 使用中国剩余定理合并结果
代码:自己写(有需要可以私信作者)
0x06 使用 avx2/avx512f 指令集
由于 bh1234666 已经写过了 avx2 的版本,那我就写一下 avx512f 的版本:
__m512i 变量能够同时存储/计算 uint64_t、uint32_t、uint16_t 或 uint8_t。
所以我们可以把本题的代码写成这样:
inline __m512i add(__m512i a, __m512i b) {
__m512i sum = _mm512_add_epi32(a, b); // 8 个 32 位整数计算加法
__mmask16 overflow_mask = _mm512_cmpge_epu32_mask(sum, mod1); // 这里有个容易错的地方,__m512 的 cmp 返回的是 16 位掩码,和 __m256i 的 cmp 返回值不同
return _mm512_mask_sub_epi32(sum, overflow_mask, sum, mod1); // 这里不能忘记判断有没有超过 mod,但好像 bh1234666 忘记判断了……
}
inline __m512i mul512(__m512i _num1, __m512i _num2) {
__m512i _num3 = _num1, _num4, _num5 = _num2;
_num2 = _mm512_mul_epu32(_num1, _num2); // 8 个 32 位整数计算乘法
_num1 = _mm512_mul_epu32(_mm512_mul_epu32(_num2, R), mod1);
_num4 = _mm512_srli_epi64(_mm512_add_epi64(_num1, _num2), 32);
_num1 = _mm512_bsrli_epi128(_num3, 4); // 每个 128 位通道右移 4 字节
_num2 = _mm512_bsrli_epi128(_num5, 4);
_num2 = _mm512_mul_epu32(_num1, _num2);
_num1 = _mm512_mul_epu32(_mm512_mul_epu32(_num2, R), mod1);
_num1 = _mm512_and_si512(_mm512_add_epi64(_num1, _num2), hi32);
__m512i result = _mm512_or_si512(_num1, _num4);
__mmask16 overflow_mask = _mm512_cmpge_epu32_mask(result, mod1);
return _mm512_mask_sub_epi32(result, overflow_mask, result, mod1); // 这里也不能忘记判断有没有超过 mod
}
这里 是全部代码(avx512f 优化)
0x07 写在最后
- 该题解的做法借鉴了 bh1234666 的做法(虽然 bh1234666 的
mul函数其实写错了)。 - 由于作者只是一个普通的竞赛生,所以对硬件的了解并不深刻,如果上文有错误,欢迎指出。
- 为了 AC 这道题,作者从某天的 23:00 开始卡常到了第二天的 7:00,所以千万要点赞之后再走,也希望管理员大大通过本篇题解。
- 实际上循环展开一般展开 2 或 4 或 8 次,题解中的展开
64 次只是作为一个示例(表明展开太多次实际上没有效果)。 - 很显然这篇题解的 AC 代码还有极大的提升空间,如果有更好的做法,或在 GCC 9 标准能通过的方法,也欢迎指出!