题解:P1495 【模板】中国剩余定理(CRT)/ 曹冲养猪

· · 题解

P1495 【模板】中国剩余定理(CRT)/ 曹冲养猪 题解

题意

试求如下同余方程组的最小非负整数解:

\begin{cases} x\equiv a_1\pmod {b_1}\\ x\equiv a_2\pmod {b_2}\\ \cdots\\ x\equiv a_n\pmod {b_n}\\ \end{cases}

其中,b_i 两两互质。

思路

对于这样模数两两互质的题,可以使用中国剩余定理来求解。什么是中国剩余定理呢?请听我慢慢道来……

首先,由余数的可加性可知,若可以求得以下 n 个方程的解,那便可以求出原方程的解:

\begin{cases} x_1\equiv a_1\pmod {b_1}\\ x_1\equiv 0\pmod {b_2}\\ x_1\equiv 0\pmod {b_3}\\ \cdots\\ x_1\equiv 0\pmod {b_n}\\ \end{cases} \begin{cases} x_2\equiv 0\pmod {b_1}\\ x_2\equiv a_2\pmod {b_2}\\ x_2\equiv 0\pmod {b_3}\\ \cdots\\ x_2\equiv 0\pmod {b_n}\\ \end{cases} \cdots \begin{cases} x_n\equiv 0\pmod {b_1}\\ x_n\equiv 0\pmod {b_2}\\ x_n\equiv 0\pmod {b_3}\\ \cdots\\ x_n\equiv a_n\pmod {b_n}\\ \end{cases}

此时,便有原方程的解 x=x_1+x_2+\cdots+x_n

现在,考虑将每一个方程组拿出来单独求解。我拿第一个方程组讲解:

\begin{cases} x_1\equiv a_1\pmod {b_1}\\ x_1\equiv 0\pmod {b_2}\\ x_1\equiv 0\pmod {b_3}\\ \cdots\\ x_1\equiv 0\pmod {b_n}\\ \end{cases}

要求出如上这个方程的解,其实可以再做一步转换,先求出以下这个方程的解:

\begin{cases} y_1\equiv 1\pmod {b_1}\\ y_1\equiv 0\pmod {b_2}\\ y_1\equiv 0\pmod {b_3}\\ \cdots\\ y_1\equiv 0\pmod {b_n}\\ \end{cases}

由余数的可乘性得:x_1=y_1\times a_1。同时,观察方程组可知 y_1 一定是 \prod_{i=2}^{n} b_i 的倍数。设 \prod_{i=2}^{n} b_im_1,则有 y_1=m_1\times k_1。因为 y_1 还要满足第一个同余方程,所以此时得到:m_1\times k_1\equiv 1\pmod {b_1}。于是,求 y_1 的问题就转换成了求 m_1\bmod b_1 的逆元的问题,可以用扩展欧几里得算法求解。

于是便可以按这种方法,求出上面的 n 个方程组的解,然后便得到了原方程的整数解。不过还要再进行对所有的数的乘积的取模,因为要求最小非负整数解

所以,设所有的数的乘积为 k,则最终答案用式子表达便是:

x=\sum_{i=1}^{n}k_i\times a_i\times m_i\pmod k

代码

开了 __int128

#include<bits/stdc++.h>
#define int __int128
using namespace std;
const int MAXN = 1e5 + 10;
int n;
int a[MAXN] , b[MAXN];
int s;

int exgcd(int a , int b , int &x , int &y) {
    if(!b) {
        x = 1;
        y = 0;
        return a;
    }
    int r = exgcd(b , a % b , x , y);
    int t = x;
    x = y;
    y = t - a / b * y;
    return r;
}

inline int read() {
    char c = getchar();
    int x = 0 , s = 1;
    while(c < '0' || c > '9') {
        if(c == '-')
            s = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * s;
}

void write(int x){
    if(x > 9)
        write(x/10);
    putchar(x % 10 | 48);
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    n = read();
    int ans = 1;
    for(register int i = 1;i <= n;i ++) {
        a[i] = read();
        b[i] = read();
        ans *= a[i];
    }
    for(register int i = 1;i <= n;i ++) {
        int k = ans / a[i];
        int x , y;
        exgcd(k , a[i] , x , y);
        s = s +  k * b[i] * x % ans;
    }
    write((s % ans + ans) % ans);
    return 0;
}