题解:AT_abc443_g [ABC443G] Another Mod of Linear Problem

· · 题解

abc443-g 题解

前言

为什么大家都直接用 AtCoder Library 自带的 `floor_sum` 来写，这显得我像个小丑诶。

终于 AK 了。

思路

转化一下:

\begin{aligned} &\phantom{\iff} k < (Ak+B) \bmod M\\ &\iff k+1 \le Ak+B-M\left\lfloor\frac{Ak+B}M \right\rfloor \\ &\iff \left\lfloor\frac{Ak+B}M \right\rfloor \le \frac{(A-1)k+(B-1)}M\\ &\iff \left\lfloor\frac{Ak+B}M \right\rfloor \le \left\lfloor\frac{(A-1)k+(B-1)}M\right\rfloor\\ \end{aligned}

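注意到两个分子之差为 (Ak+B)-\left((A-1)k+(B-1)\right)=k+1，因此当 k < M（即 k+1 \le M）时，\displaystyle \left\lfloor\frac{Ak+B}M \right\rfloor-\left\lfloor\frac{(A-1)k+(B-1)}M \right\rfloor 的值只可能是 0 或 1。而当 k \ge M 时，(Ak+B) \bmod M < M \le k，条件必不成立，所以可以先令 N \gets \min(N, M)，保证上面的结论对所有参与计算的 k 都成立。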

只要求 \displaystyle N-\sum_{k=0}^{N-1} \left(\left\lfloor\frac{Ak+B}M \right\rfloor-\left\lfloor\frac{(A-1)k+(B-1)}M \right\rfloor\right) 即可。

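运用 floor sum 算法即可在对数时间内求出上式。注意当 B=0 时 B-1 为负。代码中利用 \lfloor (x-1)/M \rfloor = \lfloor x/M \rfloor - [M \mid x]，改为计算 \lfloor ((A-1)k+B)/M \rfloor 的和，再减去满足 M \mid (A-1)k+B 的 k 的个数（用扩展欧几里得求出）。A=0 时条件直接化为 k < B \bmod M，单独处理即可。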

code

#include <bits/stdc++.h>
using namespace std;

#define ll __int128

// Extended Euclidean algorithm on ll (__int128).
// Returns g = gcd(|a|, |b|) (always >= 0) and sets x, y such that
//   a * x + b * y == g.
// Any signs of a and b are accepted: C++ truncating division still satisfies
// a % b == a - (a / b) * b, which is all the back-substitution relies on.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (b == 0)
    {
        // a * (+-1) == |a|
        x = (a >= 0 ? 1 : -1);
        y = 0;
        // Take |a| in __int128 itself: llabs() would first narrow a to
        // long long and silently truncate values outside the 64-bit range.
        return a >= 0 ? a : -a;
    }

    ll x1, y1;
    const ll g = exgcd(b, a % b, x1, y1);
    // b*x1 + (a - (a/b)*b)*y1 == g  =>  a*y1 + b*(x1 - (a/b)*y1) == g
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Floor sum: returns sum_{i=0}^{n-1} floor((a * i + b) / m).
// Requires n >= 0 and m >= 1. a and b may be negative: they are first shifted
// into [0, m) and the multiples of m removed by the shift are subtracted back
// exactly, so the result is the true mathematical floor sum.
// The main loop is the Euclid-like reduction used by AtCoder Library's
// floor_sum; all arithmetic is done in ll (__int128).
ll getsum(ll n, ll m, ll a, ll b)
{
    ll ans = 0;

    // a = a2 - t*m with a2 in [0, m), t = (a2 - a) / m >= 0:
    // floor((a*i + b)/m) = floor((a2*i + b)/m) - t*i, and sum of i is n(n-1)/2.
    if (a < 0)
    {
        const ll a2 = (a % m + m) % m;
        ans -= n * (n - 1) / 2 * ((a2 - a) / m);
        a = a2;
    }
    // Same shift for b: every one of the n terms loses (b2 - b) / m.
    if (b < 0)
    {
        const ll b2 = (b % m + m) % m;
        ans -= n * ((b2 - b) / m);
        b = b2;
    }

    while (1)
    {
        // Peel off the integer part of a/m: contributes q * (0 + 1 + ... + n-1).
        if (a >= m)
        {
            ll q = a / m;
            ans += (n - 1) * n / 2 * q;
            a %= m;
        }

        // Peel off the integer part of b/m: contributes q to each of n terms.
        if (b >= m)
        {
            ll q = b / m;
            ans += n * q;
            b %= m;
        }

        // Now 0 <= a, b < m. If even the largest numerator bound stays below m,
        // every remaining term is 0.
        ll y = a * n + b;
        if (y < m)
            break;

        // Count lattice points from the other axis: swap the roles of a and m.
        n = y / m;
        b = y % m;
        swap(a, m);
    }

    return ans;
}

void please_ac()
{
    long long n, m, a, b;
    cin >> n >> m >> a >> b;

    if (a == 0)
    {
        ll r = b % m;
        long long ans = min((ll)n, r);

        cout << ans << "\n";
        return;
    }

    ll c = a - 1;
    ll sa = getsum(n, m, a, b), sc = getsum(n, m, c, b), s2 = sa - sc;

    ll cnt0 = 0;
    if (c == 0)
    {
        if ((b % m) == 0)
            cnt0 = n;
        else
            cnt0 = 0;
    }
    else
    {
        ll x, y;
        ll g = __gcd(llabs(c), m);

        if ((b % g) != 0)
            cnt0 = 0;
        else
        {
            ll c_ = c / g;
            ll m_ = m / g;

            ll rhs = ((-b / g) % m_ + m_) % m_;

            ll x_, y_;
            ll gc = exgcd(c_, m_, x_, y_);

            ll inv = (x_ % m_ + m_) % m_;
            ll k0 = ((__int128)rhs * inv) % m_;

            if (k0 < n)
                cnt0 = 1 + (n - 1 - k0) / m_;
            else
                cnt0 = 0;
        }
    }

    ll s1 = n - cnt0;

    long long ans = max((ll)0, s1 - s2);
    cout << ans << "\n";
}

int main()
{
    // Untie the C++ streams from C stdio and from each other for fast bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // The first token is the number of independent test cases.
    int cases = 1;
    cin >> cases;

    while (cases-- != 0)
    {
        please_ac();
    }

    return 0;
}