题解:P12045 [USTCPC 2025] 一道数学题

· · 题解

基础情况分析

首先,我们分析当 n = \frac{m \times (m + 1)}{2} 中最小值的情况,使用以下代码暴力求出前 10 个最小值:

#include <bits/stdc++.h>
using namespace std;

int main() {
    for (int i = 2; i <= 11; ++i) {
        int mn = INT_MAX;
        vector<int> a(i);
        for (int j = 0; j < i; ++j) a[j] = j + 1;

        do {
            int tmp = a[0] * a[i - 1];
            for (int j = 1; j < i; ++j) tmp += a[j] * a[j - 1];
            mn = min(mn, tmp);
        } while (next_permutation(a.begin(), a.end()));

        cout << mn << ' ';
    }
    return 0;
}

求得前 10 个最小值分别为:

4 11 21 37 58 87 123 169 224 291

此时,我们对它进行三阶差分。

0️⃣ 1️⃣ 2️⃣ 3️⃣ 4️⃣ 5️⃣ 6️⃣ 7️⃣ 8️⃣ 9️⃣ ...
原数列 4 11 21 37 58 87 123 169 224 291 ...
\Delta^1 7 10 16 21 29 36 46 55 67 ...
\Delta^2 3 6 5 8 7 10 9 12 ...
\Delta^3 3 -1 3 -1 3 -1 3 ...

可以看到,第三阶的值在 3-1 间交替变化。

拓展分析

接下来拓展到 n > \frac{m \times (m + 1)}{2} 的情况。

很容易可以看到,当把多出来的值分配给 a_i 时,对总值的贡献是 a_i \times (a_{i-1} + a_{i+1})

进而发现,n 每增加 1 的时候,对总值最小的贡献是 3 (当 a_i 的两边分别为 12 的时候)

注意: 此时还需特判 n = 2 的情况。

if (m == 2) {
    cout << (n - 1 + n - 1) % MOD << endl;
    return;
}

完整代码

#include <bits/stdc++.h>

#define LL long long
#define endl '\n'

using namespace std;

const LL MOD = 998244353;
LL d[4][200010], tmp[] = {-1, 3};

void solve() {
    LL n, m;
    cin >> m >> n;

    if (m == 2) {
        cout << (n - 1 + n - 1) % MOD << endl;
        return;
    }

    for (int i = 1; i < 200005; ++i) d[3][i] = tmp[i & 1];
    d[2][1] = 3;
    d[1][1] = 7;
    d[0][1] = 4;
    for (int i = 1; i < 200005; ++i) {
        d[2][i + 1] = (d[2][i] + d[3][i]) % MOD;
        d[1][i + 1] = (d[1][i] + d[2][i]) % MOD;
        d[0][i + 1] = (d[0][i] + d[1][i]) % MOD;
    }
    n -= m * (m + 1) / 2ll;
    cout << (d[0][m - 1] + 3ll * n) % MOD << endl;
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int T = 1;
    //cin >> T;
    while (T--) solve();
    return 0;
}