题解:P12047 [USTCPC 2025] 翻转数字

· · 题解

有幸成为 USTCPC(包括线上、线下、锣鼓重现赛)这题的唯一通过者,因此来贴一下之前写好的题解。比赛完整题解之后会发布在我的 cnblogs。

当之无愧的防 AK 题,需要用到一连串的结论,用 Python 验证一下会更放心。

首先考虑操作可达什么串。当没有对称字符的概念时,经典的结论是奇偶下标可以分别乱序排列,但奇数下标不能和偶数下标互换。有对称字符概念时结论略有变化,只有非对称的字符下标奇偶性不能变,其余字符可以任意排列。

那么对于一个子串的答案,我们就有了一个贪心算法:首先 9 显然一定可以(在对应的奇偶位置)尽量往前填,之后 8 在能够形成合法串(即不占了 5 的位置)的前提下也能尽量往前填。接下来仍然 5 在对应的奇偶位置上贪心填;最后 1 和 0 都是对称的,一定是 1 在前面,0 在后面。这个算法可以 O(n) 实现,那么我们就有了 O(n^3) 的解法。多项式复杂度毫无疑问是一个很好的起点。

事实上我们可以继续优化到 O(n^2)。关键的观察是固定左端点移动右端点时,可以 O(1)(或者说 O(\Sigma)?) 地增量更新当前的最优解和答案。虽然我没有实现完全一样的算法(甚至利用了随机性质),第一天(做 N 题的第一天,也就是第三天)的解法就停留在了这个复杂度,直到我发现 n=50000 而不是 5000 所以 c\times5n^2 的计算量显然是过不了的。

但我们目前还没有利用随机性质呢。事实上目前算法竞赛题对随机性质的利用都很老套,无非就是大数定律那一套,两个概率一样的东西在 n 比较大的时候不会相差很多,概率不一样的东西在 n 比较大的时候一定会相差很多……这一题就很容易观察出:等概率随机,说明对于较长的串,9、8、5、1、0 的数量相差不大,那么最优串的形态一定类似于 999...8989888...555...1515111...00000,也就是说 9/8, 5/1, 0 之间分割明显,但由于奇偶的少量差距,8 的开头会有少量的同奇偶的 9,1 的开头有少量同奇偶的 5。

所以我们的做法肯定是对于短串先暴力算,对于长串利用这个性质去快速算。经过 Python 蒙特卡洛估计,当分界线取 64 时,概率已经在 10^{-7} 量级了,但由于我们有 n^2 个子串,所以错误概率要限制在 \omega(n^{-2}),取 160 可以通过此题。事实上错误概率一般都是 O(\text{poly}(e^{-O(\text{poly}(n))})) 的,所以这个阈值一般来说是 O(\text{poly}(\log n))的。

事实上我第一天的想法就是利用这个性质去做增量更新然后 O(n^2),在失败后,当晚睡着之前我就一直在想如何简化整个串的所有子串的答案计算,然后我发现是可以简化到 O(n) 的。我把整个思维过程陈述如下:

\begin{align} \text{ans} &= \sum_{i=1}^n \sum_{l=1}^i \sum_{r=i}^n [a_i=1] 10^{\text{cnt\_right\_10} + \text{cnt\_left\_0}} \\ & = \sum_{i=1}^n \sum_{l=1}^i \sum_{r=i}^n [a_i=1] 10^{\text{pres\_cnt\_right\_10}_l - \text{pres\_cnt\_right\_10}_i + \text{suf\_cnt\_left\_0}_r - \text{suf\_cnt\_left\_0}_i} \\ &= \sum_{i=1}^n [a_i=1] 10^{-f(i)}(\sum_{l=1}^i 10^{f(l)})(\sum_{r=i}^n 10^{f(r)}) \end{align}

发现了吗?我们把一个位置的贡献拆成了两项可以预处理的求和之积,这样就 O(n) 解决了这个问题!并且这个思路推广到只有 018 的串是毫无压力的。

\begin{aligned} ans_i &= \sum_{l=1}^i \sum_{r=i}^n 10^{r - l - 2 \text{left\_sameparity\_parity9\_cnt} -((i - l) \bmod 2)} \\ &= \sum_{l=1}^i \sum_{r=i}^n 10^{r - l - 2 \text{suf\_cnt9\_parity}_{i\bmod2, l} + 2 \text{suf\_cnt9\_parity}_{i\bmod2, i} -((i - l) \bmod 2)} \\ &= f(i) (\sum_{l=1}^i 10^{-l - f(l) - ((i - l) \bmod 2)})(\sum_{r=i}^n 10^r) \end{aligned}

我们发现 l 的式子里带上了 i,这是我们不希望的;但是只要对 l 分奇偶预处理和,再加起来就可以解决。于是我们化解了来自 9 的挑战。大概想到这里我就睡着了。

第二天起来就开始细化解法的同时写代码。然而写到 5 的时候我发现事情好像不那么简单,式子中涉及了同时与 i, l, r 有关的项……我差点以为三百多行的代码就要交代在这了,仔细考虑后发现,只要按照 (i + \text{suf\_cnt\_8}_{r + 1}) + (l + \text{suf\_cnt\_8}_l) \mod 2 分类,多做一点类内预处理就好了。于是我这里就略过这最复杂的一部分的推导了(代码中有写这一部分的式子……)。

总之,我们先根据上面的线性做法计算整个串符合假设时的答案,而对于短串暴力计算答案并减去可能错误计算的答案,整个题目就在 O(n B^2 + n) = O(n \ \text{poly log}(n)) 的时间内完成了。代码很难写,非常需要耐心。据说 std 是更好写的写法。

完整代码:

#include <bits/stdc++.h>
using namespace std;
using i64 = int64_t;
using u64 = uint64_t;
constexpr i64 MOD = 998244353, N = 5e4 + 500 + 10, B = 128 + 32;
i64 p10[N + 1]{}, ip10[N + 1]{}, m11[N + 1]{}, m1010[N + 1]{};
i64 sorted_arr[N + 1]{}, a[N + 1]{};
i64 inv10, inv100;

i64 p10_suf_cnt0[N + 1]{}, pre_p10_suf_cnt0[N + 1]{};
i64 p10_pre_cnt01[N + 1]{}, suf_p10_pre_cnt01[N + 1]{};

i64 p10_suf_cnt01[N + 1]{}, pre_p10_suf_cnt01[N + 1]{};
i64 p10_pre_cntall[N + 1]{}, suf_p10_pre_cntall[N + 1]{};

i64 suf_p10[N + 1]{};
i64 pi100_suf_par_cnt9[2][N + 1]{}, pre_par_pi100_suf_par_cnt9[2][2][N + 1]{};

i64 par_suf_p10[2][N + 1]{};
i64 suf_cnt8[N + 1]{}, p10_suf_cnt8[N + 1]{}, pre_p10_suf_cnt8[N + 1]{};
i64 pi100_suf_par_cnt5[2][N + 1]{}, pre_par_pi100_suf_par_cnt5[2][2][N + 1]{};

using CountType = array<array<int, 10>, 2>;

i64 modpow(i64 a, i64 b, i64 M) {
    i64 res = 1;
    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

i64 modinv(i64 a) { return modpow(a, MOD - 2, MOD); }

void pre_calc() {
    fill(a, a + N + 1, 0);
    inv10 = modinv(10);
    p10[0] = ip10[0] = 1;
    for (int i = 1; i <= N; i++) {
        p10[i] = (p10[i - 1] * 10) % MOD;
        ip10[i] = (ip10[i - 1] * inv10) % MOD;
    }
    m11[0] = 0;
    for (int i = 1; i <= N; i++) {
        // 11111... (i 1s)
        m11[i] = (m11[i - 1] * 10 + 1) % MOD;
    }
    inv100 = modinv(100);
    m1010[1] = 1;
    for (int i = 1; i <= N; i++) {
        // 1.01010101... (i 1s)
        m1010[i] = (m1010[i - 1] * inv100 + 1) % MOD;
    }
}

i64 brute_calc(const CountType &cnt, int n) {
    fill(sorted_arr, sorted_arr + n, -1);
    array<int, 2> rem = {(n + 1) / 2, n / 2};
    int p0 = 0, p1 = 1, p = 0;
    for (int i = 0; i < cnt[0][9]; i++) {
        sorted_arr[p0] = 9;
        p0 += 2;
        rem[0]--;
    }
    for (int i = 0; i < cnt[1][9]; i++) {
        sorted_arr[p1] = 9;
        p1 += 2;
        rem[1]--;
    }
    p = min(p0, p1);
    for (int i = 0; i < cnt[0][8] + cnt[1][8]; i++) {
        while (true) {
            while (sorted_arr[p] != -1) p++;
            bool feasible = rem[p % 2] > cnt[p % 2][5];
            if (feasible) {
                sorted_arr[p] = 8;
                rem[p % 2]--;
                p++;
                break;
            }
            p++;
        }
    }
    for (int i = 0; i < cnt[0][5]; i++) {
        while (sorted_arr[p0] != -1) p0 += 2;
        sorted_arr[p0] = 5;
        p0 += 2;
    }
    for (int i = 0; i < cnt[1][5]; i++) {
        while (sorted_arr[p1] != -1) p1 += 2;
        sorted_arr[p1] = 5;
        p1 += 2;
    }
    for (int i = 0; i < cnt[0][1] + cnt[1][1]; i++) {
        while (sorted_arr[p] != -1) p++;
        sorted_arr[p] = 1;
        p++;
    }
    for (int i = 0; i < cnt[0][0] + cnt[1][0]; i++) {
        while (sorted_arr[p] != -1) p++;
        sorted_arr[p] = 0;
        p++;
    }

    i64 res = 0;
    for (int i = 0; i < n; i++) {
        if (sorted_arr[i] == -1) {
            // for (int i = 0; i < n; i++) {
            //     cout << sorted_arr[i] << ", ";
            // }
            // cout << endl;
            // for (int par: {0, 1})
            //     for (int x: {0, 1, 5, 8, 9}) {
            //         cout << par << " " << x << " " << cnt[par][x] << endl;
            //     }
            exit(0);
        }
        res += p10[n - i - 1] * sorted_arr[i];
    }
    return res % MOD;
}

// This is O(1) but with huge constant
i64 errornous_calc(const CountType &cnt, int l, int r) {
    // first calc 8...81...10...0
    i64 c8 = cnt[0][8] + cnt[1][8] + cnt[0][9] + cnt[1][9],
        c1 = cnt[0][1] + cnt[1][1] + cnt[0][5] + cnt[1][5], c0 = cnt[0][0] + cnt[1][0];
    i64 vl = 8 * m11[c8], vm = m11[c1];
    i64 bm = p10[c1], br = p10[c0];
    i64 res = (vl * bm + vm) % MOD * br % MOD;
    // then add 9 and 5
    res += p10[r - l] * m1010[cnt[0][9]] % MOD; // 1.0101... * 10^(r - l)
    if (r > l) {
        res += p10[r - l - 1] * m1010[cnt[1][9]] % MOD;
    } else {
        res += ip10[l + 1 - r] * m1010[cnt[1][9]] % MOD;
    }
    i64 l5 = l + c8;
    res += p10[r - l5] * m1010[cnt[c8 % 2][5]] * 4 % MOD; // 5 - 1
    if (r > l5) {
        res += p10[r - l5 - 1] * m1010[cnt[(c8 + 1) % 2][5]] * 4 % MOD;
    } else {
        res += ip10[l5 + 1 - r] * m1010[cnt[(c8 + 1) % 2][5]] * 4 % MOD;
    }
    return res % MOD;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    pre_calc();

    string s; cin >> s;
    int n = s.size();

    // mt19937_64 rng(11451419198);
    // int n = 5000;
    // string s(n, '6');
    // for (int i = 0; i < n; i++) {
    //     s[i] = "19580"[rng() % 5];
    // }
    for (int i = 1; i <= n; i++) {
        a[i] = s[i - 1] - '0';
    }

    // CountType cnt{};
    // for (int i = 1; i <= n; i++) {
    //     cnt[i % 2][a[i]]++;
    // }
    // cout << errornous_calc(cnt, 1, n);
    // return 0;

    i64 ans = 0;
    for (int l = 1; l <= n; l++) {
        CountType cnt{};
        i64 now_ans = 0;
        for (int r = l; r <= l + B && r <= n; r++) {
            cnt[(r - l) % 2][a[r]]++;
            now_ans = brute_calc(cnt, r - l + 1) % MOD;
            now_ans += (MOD - errornous_calc(cnt, l, r)) % MOD;
            // cout << l << " " << r << " " << errornous_calc(cnt, l, r) << "\n";
            ans = (ans + now_ans) % MOD;
        }
        ans %= MOD;
        if (l + B >= n) continue;
    }
    // cout << (ans % MOD) << ": less than B\n";

    // we need some precalc first
    // for example, the contribution of one 1 is
    //   sum_l sum_r 10^ {cnt0_left + cnt01_right}
    // = sum_l sum_r 10 ^ {cnt0_suf_l - cnt0_suf_i + cnt01_pre_r - cnt01_pre_i}
    // = sum_l (10 ^ cnt0_suf_l) * sum_r (10 ^ cnt01_pre_r) * 10 ^ (- cnt0_suf_i - cnt01_pre_i)
    p10_suf_cnt0[n + 1] = 1;
    for (int i = n; i >= 1; i--) {
        p10_suf_cnt0[i] = p10_suf_cnt0[i + 1];
        if (a[i] == 0) (p10_suf_cnt0[i] *= 10) %= MOD;
    }
    for (int i = 1; i <= n; i++) {
        (pre_p10_suf_cnt0[i] = pre_p10_suf_cnt0[i - 1] + p10_suf_cnt0[i]) %= MOD;
    }
    p10_pre_cnt01[0] = 1;
    for (int i = 1; i <= n; i++) {
        p10_pre_cnt01[i] = p10_pre_cnt01[i - 1];
        if (a[i] == 0 || a[i] == 1 || a[i] == 5) (p10_pre_cnt01[i] *= 10) %= MOD;
    }
    for (int i = n; i >= 1; i--) {
        (suf_p10_pre_cnt01[i] = suf_p10_pre_cnt01[i + 1] + p10_pre_cnt01[i]) %= MOD;
    }

    i64 ans1 = 0;
    for (int i = 1; i <= n; i++) {
        if (a[i] == 1 || a[i] == 5) {
            ans1 += pre_p10_suf_cnt0[i] * suf_p10_pre_cnt01[i] % MOD *
                modinv(p10_suf_cnt0[i]) % MOD * modinv(p10_pre_cnt01[i]) % MOD;
            // cout << i << " " << p10_suf_cnt0[i] << " " << pre_p10_suf_cnt0[i] << endl;
            // cout << i << "[ans1]" << pre_p10_suf_cnt0[i] * suf_p10_pre_cnt01[i] % MOD *
            //     modinv(p10_suf_cnt0[i]) % MOD * modinv(p10_pre_cnt01[i]) % MOD << endl;
            ans1 %= MOD;
        }
    }
    (ans += ans1) %= MOD;
    // cout << ans1 << ": ans1" << endl;

    // for 8 we have the same logic.
    // sum_l (10 ^ cnt01_suf_l) * sum_r (10 ^ cnt_all_pre_r) * 10 ^ (- cnt01_suf_i - cnt_all_pre_i)
    p10_suf_cnt01[n + 1] = 1;
    for (int i = n; i >= 1; i--) {
        p10_suf_cnt01[i] = p10_suf_cnt01[i + 1];
        if (a[i] == 0 || a[i] == 1 || a[i] == 5) (p10_suf_cnt01[i] *= 10) %= MOD;
    }
    for (int i = 1; i <= n; i++) {
        (pre_p10_suf_cnt01[i] = pre_p10_suf_cnt01[i - 1] + p10_suf_cnt01[i]) %= MOD;
    }
    p10_pre_cntall[0] = 1;
    for (int i = 1; i <= n; i++) {
        p10_pre_cntall[i] = p10_pre_cntall[i - 1];
        (p10_pre_cntall[i] *= 10) %= MOD;
    }
    for (int i = n; i >= 1; i--) {
        (suf_p10_pre_cntall[i] = suf_p10_pre_cntall[i + 1] + p10_pre_cntall[i]) %= MOD;
    }
    i64 ans8 = 0;
     for (int i = 1; i <= n; i++) {
        if (a[i] == 9 || a[i] == 8) {
            ans8 += 8 * pre_p10_suf_cnt01[i] * suf_p10_pre_cntall[i] % MOD *
                modinv(p10_suf_cnt01[i]) % MOD * modinv(p10_pre_cntall[i]) % MOD;
            ans8 %= MOD;
        }
    }
    // cout << ans8 << ": ans8" << endl;

    (ans += ans8) %= MOD;

    // for 9 it's harder...
    // if 9 and l have the same parity:
    //   sum_l sum_r 10 ^ {r - l - 2 * cnt_left[i % 2][9])}
    // = sum_l sum_r 10 ^ {r - l - 2 * suf_cnt9_[i % 2]_l + 2 * suf_cnt9_[i % 2]_i}
    // = (sum_l_{same parity} 10 ^ {-l - 2 * suf_cnt9_[i % 2]_l}) * (sum_r 10^r) * 10 ^ {2 * suf_cnt9_[i % 2]_i}
    // otherwise:
    //   (sum_l_{diff parity} 10 ^ {-l - 2 * suf_cnt9_[i % 2]_l}) * (sum_r 10^r) / 10 * 10 ^ {2 * suf_cnt9_[i % 2]_i}
    for (int i = n; i >= 1; i--) {
        suf_p10[i] = (suf_p10[i + 1] + p10[i]) % MOD;
    }
    for (int par: {0, 1}) {
        pi100_suf_par_cnt9[par][n + 1] = 1;
        for (int i = n; i >= 1; i--) {
            pi100_suf_par_cnt9[par][i] = pi100_suf_par_cnt9[par][i + 1];
            if (i % 2 == par && a[i] == 9) {
                (pi100_suf_par_cnt9[par][i] *= inv100) %= MOD;
            }
        }
        for (int sum_par: {0, 1}) {
            for (int i = 1; i <= n; i++) {
                pre_par_pi100_suf_par_cnt9[sum_par][par][i] = pre_par_pi100_suf_par_cnt9[sum_par][par][i - 1];
                if (i % 2 == sum_par) {
                    (pre_par_pi100_suf_par_cnt9[sum_par][par][i] += pi100_suf_par_cnt9[par][i] * ip10[i] % MOD) %= MOD;
                }
            }
        }
    }

    i64 ans9 = 0;
    for (int i = 1; i <= n; i++) {
        if (a[i] == 9) {
            ans9 += pre_par_pi100_suf_par_cnt9[i % 2][i % 2][i] * suf_p10[i] % MOD
                * modinv(pi100_suf_par_cnt9[i % 2][i]) % MOD;
            ans9 += pre_par_pi100_suf_par_cnt9[(i % 2) ^ 1][i % 2][i] * suf_p10[i] % MOD * inv10 % MOD
                * modinv(pi100_suf_par_cnt9[i % 2][i]) % MOD;
            ans9 %= MOD;
            // cout << i << " " << ans9 << "\n";
        }
    }
    (ans += ans9) %= MOD;
    // cout << ans9 << ": ans9" << endl;

    // The most difficult part [panic]
    // The contribution of 5
    // if i and (l - cnt8) have the same parity:
    //   sum_l sum_r 10 ^ {r - l - cnt8 - 2 * cnt_left[i % 2][5])}
    // = sum_l sum_r 10 ^ {r - l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l + 2 * suf_cnt5_[i % 2]_i + suf_cnt8_(r+1))}
    // = (sum_l_{same parity} 10 ^ {-l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l}) * (sum_r_{same parity} 10^ {r + suf_cnt8_(r+1)}) * 100 ^ suf_cnt5_[i % 2]_i

    // otherwise just / 10

    // How to determine the parity? (i + suf_cnt_8[r + 1]) and (l + suf_cnt_8[l]) should have the same parity.
    // It doesn't seem as hard as it looks.

    suf_cnt8[n + 1] = 0; p10_suf_cnt8[n + 1] = 1;
    for (int i = n; i >= 1; i--) {
        suf_cnt8[i] = suf_cnt8[i + 1]; p10_suf_cnt8[i] = p10_suf_cnt8[i + 1];
        if (a[i] == 9 || a[i] == 8) {
            suf_cnt8[i] += 1;
            (p10_suf_cnt8[i] *= 10) %= MOD;
        }
    }
    pre_p10_suf_cnt8[0] = 0;
    for (int i = 1; i <= n; i++) {
        (pre_p10_suf_cnt8[i] = pre_p10_suf_cnt8[i - 1] + p10_suf_cnt8[i]) %= MOD;
    }

    for (int par: {0, 1}) {
        for (int i = n; i >= 1; i--) {
            par_suf_p10[par][i] = par_suf_p10[par][i + 1];
            if (suf_cnt8[(i + 1)] % 2 == par) {
                // sum_r_{same parity} 10 ^ {r + suf_cnt8_(r+1)}
                (par_suf_p10[par][i] += p10[i] * p10_suf_cnt8[i + 1] % MOD) %= MOD;
            }
        }
    }

    for (int par: {0, 1}) {
        pi100_suf_par_cnt5[par][n + 1] = 1;
        for (int i = n; i >= 1; i--) {
            pi100_suf_par_cnt5[par][i] = pi100_suf_par_cnt5[par][i + 1];
            if (i % 2 == par && a[i] == 5) {
                (pi100_suf_par_cnt5[par][i] *= inv100) %= MOD;
            }
        }
        for (int sum_par: {0, 1}) {
            for (int i = 1; i <= n; i++) {
                pre_par_pi100_suf_par_cnt5[sum_par][par][i] = pre_par_pi100_suf_par_cnt5[sum_par][par][i - 1];
                if ((i + suf_cnt8[i]) % 2 == sum_par) {
                    // 10 ^ {-l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l})
                    (pre_par_pi100_suf_par_cnt5[sum_par][par][i] += pi100_suf_par_cnt5[par][i] * ip10[i] % MOD
                        * modinv(p10_suf_cnt8[i]) % MOD) %= MOD;
                }
            }
        }
    }

    i64 ans5 = 0;
    for (int i = 1; i <= n; i++) {
        if (a[i] == 5) {
            i64 lans5 = ans5;
            int ipar = i % 2;
            for (int lpar: {0, 1}) for (int rpar: {0, 1}) {
                if ((ipar + lpar + rpar) % 2 == 0) {
                    ans5 += 4 * pre_par_pi100_suf_par_cnt5[lpar][ipar][i] * par_suf_p10[rpar][i] % MOD
                        * modinv(pi100_suf_par_cnt5[ipar][i]) % MOD;
                } else {
                    ans5 += 4 * pre_par_pi100_suf_par_cnt5[lpar][ipar][i] * par_suf_p10[rpar][i] % MOD * inv10 % MOD
                        * modinv(pi100_suf_par_cnt5[ipar][i]) % MOD;
                }
                ans5 %= MOD;
            }
            // cout << i << "[10 diff ans5]" << (MOD + ans5 - lans5) * 10 % MOD << endl;
        }
    }
    // cout << ans5 << ": ans5" << endl;
    (ans += ans5) %= MOD;

    cout << ans << "\n";

    // i64 ans_brute = 0;
    // for (int l = 1; l <= n; l++) {
    //     CountType cnt{};
    //     for (int r = l; r <= n; r++) {
    //         cnt[(r - l) % 2][a[r]]++;
    //         (ans_brute += brute_calc(cnt, r - l + 1)) %= MOD;
    //     }
    // }
    // cout << ans_brute << "\n";
}