浅谈整数快速取模

· · 算法·理论

我们对于固定的模数 p,往往有一些通过位运算和预处理对取模的优化。

barrett 约减

问题:给定除数 M\in \N 和被除数的范围 [1, N]\cap\N,快速进行除法。

容易注意到 \dfrac{x}{M} = \dfrac{x}{2^k} \times \dfrac{2^k}{M} \approx \dfrac{x}{2^k} \times \left \lfloor \dfrac{2^k}{M}\right \rfloor,那么我们考虑预处理 I = \left \lfloor \dfrac{2^k}{M}\right \rfloor

除法就转化为了乘法和位移。符合问题的要求。现在问题是如何取 k 的值保证取整数部分时正确。

经过一些推导发现当 2^k > N(M - 1) 时可以保证 \left \lfloor \dfrac{xI}{2^k}\right \rfloor = \left \lfloor \dfrac{x}{M}\right \rfloor2^k > N 时误差 \le 1

然后就可以利用它优化取模运算了。然而这个东西在固定模数的时候编译器会自动加优化,而且 barrett 在大多数时候跑不过 Montgomery 模。只有在输入模数的题里边可以卡卡常数。

但是这个好背,考场上也随便写的。

::::info[一个 barrett 的实现]

class barrett {
 private:
  int64_t I;
  int32_t M;
 public:
  barrett(uint32_t p) : 
    I(((__int128)1 << 64) / p), M(p) {assert(p && p <= (int)1e9 + 9);}
  uint32_t operator ()(uint64_t x) {
    return x - ((__int128(x) * I) >> 64) * M;
  }
}; // author : Drifty

::::

Montgomery 模

我们引入一个参数 R,处于一些需要,你需要保证 \gcd(R, p) = 1, R > p

一个数 x\bmod\ p 意义下的 Montgomery 形式就是 aR \bmod p

首先 Montgomery 域上的加法和减法可以直接正常做。

考虑乘法怎么办,我们发现直接乘会变成 abR^2 这样子。那么我们引入 Montgomery 约减。具体的,一个数 a(0\le a< pR) 经过 Montgomery 约减之后就会变成 a\times R^{-1}\pmod p

我们记 p'p\equiv -1\pmod Rp'x \bmod R = t,约减的第一步是依赖于下面这个等式:

xR^{-1}\equiv \frac{x + pt}{R}\pmod p

我们下面进行证明。

  1. 证明:x + pt \equiv 0 \pmod R

    我们将上面的定义代入有:

    x + pt \equiv x + pp'x \equiv x - x = 0 \pmod R

    证毕。

  2. 证明:xR^{-1}\equiv \frac{x + pt}{R}\pmod p

    \begin{aligned} xR^{-1} &\equiv \frac{x + pt}{R}\pmod p \iff \\ xR^{-1}R &\equiv x + pt\pmod p \iff \\ x &\equiv x\pmod p \end{aligned}

    证毕。

网上蛮多文章要么没给证明,要么定义的正负号不清楚,其实这个东西并不难证明。

那么我们先计算出 \dfrac{x + pt}{R} = M,这就是第一步。

第二步是判断一下大小,如果比 p 大就 M\leftarrow M - p

值得发现的是因为 x < pR, t < R,所以有 \dfrac{x + pt}{R} < \dfrac{2pR}{R},最多减一次 p 就能保证 M < p

然后我们在实际计算的时候只要取 R = 2^{32} 次,就可以让这里所有的除法都变成位移,加快速度。但是注意,这样子做的时候模数 p 必须为奇数,否则不满足 \gcd(R, p) = 1,这会导致 p' 可能不存在。

那么乘法的结果 abR^2,只要做一次约减就变成了 abR。同理,你也可以用这个完成快速幂,除法,将 Montgomery 形式转为普通整数等一系列操作。

算法竞赛中,该算法大多数情况下具有比 barrett 更好的效率。

::::info[一个野生的 Montgomery 模板,貌似是 Min25 老师写的]

template <std::uint32_t P> struct MontgomeryModInt32 {
public:
  using i32 = std::int32_t;
  using u32 = std::uint32_t;
  using i64 = std::int64_t;
  using u64 = std::uint64_t;

private:
  u32 v;

  static constexpr u32 get_r() {
    u32 iv = P;

    for (u32 i = 0; i != 4; ++i)
      iv *= 2 - P * iv;

    return iv;
  }

  static constexpr u32 r = -get_r(), r2 = -u64(P) % P;

  static_assert((P & 1) == 1);
  static_assert(r * P == -1);
  static_assert(P < (1 << 30));

public:
  static constexpr u32 pow_mod(u32 x, u64 y) {
    if ((y %= P - 1) < 0)
      y += P - 1;

    u32 res = 1;

    for (; y != 0; y >>= 1, x = u64(x) * x % P)
      if (y & 1)
        res = u64(res) * x % P;

    return res;
  }

  static constexpr u32 get_pr() {
    u32 tmp[32] = {}, cnt = 0;
    const u64 phi = P - 1;
    u64 m = phi;

    for (u64 i = 2; i * i <= m; ++i) {
      if (m % i == 0) {
        tmp[cnt++] = i;

        while (m % i == 0)
          m /= i;
      }
    }

    if (m > 1)
      tmp[cnt++] = m;

    for (u64 res = 2; res <= phi; ++res) {
      bool flag = true;

      for (u32 i = 0; i != cnt && flag; ++i)
        flag &= pow_mod(res, phi / tmp[i]) != 1;

      if (flag)
        return res;
    }

    return 0;
  }

  MontgomeryModInt32() = default;
  ~MontgomeryModInt32() = default;
  constexpr MontgomeryModInt32(u32 v)
      : v(reduce(u64(v) * r2)) {}
  constexpr MontgomeryModInt32(
      const MontgomeryModInt32 &rhs)
      : v(rhs.v) {}
  static constexpr u32 reduce(u64 x) {
    return x + (u64(u32(x) * r) * P) >> 32;
  }
  constexpr u32 get() const {
    u32 res = reduce(v);
    return res - (P & -(res >= P));
  }
  explicit constexpr operator u32() const { return get(); }
  explicit constexpr operator i32() const {
    return i32(get());
  }
  constexpr MontgomeryModInt32 &
  operator=(const MontgomeryModInt32 &rhs) {
    return v = rhs.v, *this;
  }
  constexpr MontgomeryModInt32 operator-() const {
    MontgomeryModInt32 res;
    return res.v = (P << 1 & -(v != 0)) - v, res;
  }
  constexpr MontgomeryModInt32 inv() const {
    return pow(-1);
  }
  constexpr MontgomeryModInt32 &
  operator+=(const MontgomeryModInt32 &rhs) {
    return v += rhs.v - (P << 1),
           v += P << 1 & -(i32(v) < 0), *this;
  }
  constexpr MontgomeryModInt32 &
  operator-=(const MontgomeryModInt32 &rhs) {
    return v -= rhs.v, v += P << 1 & -(i32(v) < 0), *this;
  }
  constexpr MontgomeryModInt32 &
  operator*=(const MontgomeryModInt32 &rhs) {
    return v = reduce(u64(v) * rhs.v), *this;
  }
  constexpr MontgomeryModInt32 &
  operator/=(const MontgomeryModInt32 &rhs) {
    return this->operator*=(rhs.inv());
  }
  friend MontgomeryModInt32
  operator+(const MontgomeryModInt32 &lhs,
            const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) += rhs;
  }
  friend MontgomeryModInt32
  operator-(const MontgomeryModInt32 &lhs,
            const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) -= rhs;
  }
  friend MontgomeryModInt32
  operator*(const MontgomeryModInt32 &lhs,
            const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) *= rhs;
  }
  friend MontgomeryModInt32
  operator/(const MontgomeryModInt32 &lhs,
            const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) /= rhs;
  }
  friend std::istream &operator>>(std::istream &is,
                                  MontgomeryModInt32 &rhs) {
    return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
  }
  friend std::ostream &
  operator<<(std::ostream &os,
             const MontgomeryModInt32 &rhs) {
    return os << rhs.get();
  }
  constexpr MontgomeryModInt32 pow(i64 y) const {
    if ((y %= P - 1) < 0)
      y += P - 1;
    MontgomeryModInt32 res(1), x(*this);
    for (; y != 0; y >>= 1, x *= x)
      if (y & 1)
        res *= x;
    return res;
  }
};

::::

觉得卡常数有用可以点个赞吗喵喵。