P1495 【模板】中国剩余定理(CRT)/ 曹冲养猪 题解

· · 题解

\large\color{00aacd}\textbf{中国剩余定理(CRT)}

前置知识:同余、逆元、扩展欧几里得定理。

\color{00cd00}\text{算法介绍}

中国剩余定理用于求解如下形式的同余方程组:

\begin{cases} x\equiv b_1 \pmod{a_1} \\ x\equiv b_2 \pmod{a_2} \\ \dots \\ x\equiv b_n \pmod{a_n} \end{cases}

其中,a_1, a_2, \dots a_n 两两互质。

求解的过程如下:

  1. 计算所有模数的积 M = \prod\limits_{i=1}^n a_i
  2. 依次考虑每一个方程。对于第 i 个方程:
    • 计算 m_i = M \div a_i
    • 计算 m_i 在模 a_i 意义下的逆元 m_i^{-1}
      • 因为 a_1, a_2, \dots, a_n 两两互质,显然有 \gcd(m_i, a_i) = 1,即逆元一定存在。
  3. 最终,x 最小的非负整数解为 x = (\sum\limits_{i=1}^n m_i\cdot m_i^{-1}\cdot b_i) \bmod M

\color{00cd00}\text{正确性证明}

我们需要证明,以上算法所得到的 x 满足 \forall i \in [1, n],\ x\equiv b_i \pmod{a_i}

根据以上 m_i 的定义,可以得到当 i\ne j 时,m_j\equiv 0 \pmod{a_i}

又根据逆元的定义,可以得到 m_i \cdot m_i^{-1} \equiv 1 \pmod{a_i}

因此:

\begin{aligned} x &\equiv \sum_{j=1}^n m_j\cdot m_j^{-1}\cdot b_j &\pmod{a_i} \\ &\equiv m_i\cdot m_i^{-1}\cdot b_i &\pmod{a_i} \\ &\equiv b_i &\pmod{a_i} \end{aligned}

即:\forall i \in [1, n],\ x\equiv b_i \pmod{a_i}。算法的正确性得证。

\color{00cd00}\text{代码实现}

由于 a_i 不一定是质数,不能用费马小定理求逆元,所以用了扩展欧几里得定理来求。

C++:

本题运算过程中可能会爆 long long,需要使用 __int128 才能通过。

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 15;
int n, a[N], b[N];
pair<int, int> exgcd(int a, int b){
    if(b == 0) return {1, 0};
    auto [x, y] = exgcd(b, a % b);
    return {y, x - a / b * y};
}
int inv(int a, int p){
    auto [x, y] = exgcd(a, p);
    return (x + p) % p; //exgcd 求出的解有可能是负的,所以取模一下转成正的
}
int CRT(int n, int *a, int *b){
    int M = 1, x = 0;
    for(int i=1; i<=n; i++) M *= a[i];
    for(int i=1; i<=n; i++){
        __int128 Mi = M / a[i], Ti = inv(Mi, a[i]);
        x = (x + Mi * Ti * b[i] % M) % M;
    }
    return x;
}
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n;
    for(int i=1; i<=n; i++) cin >> a[i] >> b[i];
    cout << CRT(n, a, b);
    return 0;
}

Python:

def exgcd(a, b):
    if b == 0: return (1, 0)
    x, y = exgcd(b, a % b)
    return y, x - a // b * y

def inv(a, p): return exgcd(a, p)[0] % p

def CRT(n, a, b):
    M = 1; x = 0
    for i in range(n): M *= a[i]
    for i in range(n):
        Mi = M // a[i]; Ti = inv(Mi, a[i])
        x += Mi * Ti * b[i]
    return x % M

def main():
    n = int(input())
    a = [0] * n; b = [0] * n
    for i in range(n):
        a[i], b[i] = map(int, input().split())
    print(CRT(n, a, b))

if __name__ == '__main__':
    main()