题解:P12047 [USTCPC 2025] 翻转数字
有幸成为 USTCPC(包括线上、线下、锣鼓重现赛)这题的唯一通过者,因此来贴一下之前写好的题解。比赛完整题解之后会发布在我的 cnblogs。
当之无愧的防 AK 题,需要用到一连串的结论,用 Python 验证一下会更放心。
首先考虑操作可达什么串。当没有对称字符的概念时,经典的结论是奇偶下标可以分别乱序排列,但奇数下标不能和偶数下标互换。有对称字符概念时结论略有变化,只有非对称的字符下标奇偶性不能变,其余字符可以任意排列。
那么对于一个子串的答案,我们就有了一个贪心算法:首先 9 显然一定可以(在对应的奇偶位置)尽量往前填,之后 8 在能够形成合法串(即不占了 5 的位置)的前提下也能尽量往前填。接下来仍然 5 在对应的奇偶位置上贪心填;最后 1 和 0 都是对称的,一定是 1 在前面,0 在后面。这个算法可以
事实上我们可以继续优化到
但我们目前还没有利用随机性质呢。事实上目前算法竞赛题对随机性质的利用都很老套,无非就是大数定律那一套,两个概率一样的东西在
所以我们的做法肯定是对于短串先暴力算,对于长串利用这个性质去快速算。经过 Python 蒙特卡洛估计,当分界线取 64 时,概率已经在
事实上我第一天的想法就是利用这个性质去做增量更新然后
- 首先考虑对只有 01 的串如何计算答案。我们把答案按照贡献写成式子:
发现了吗?我们把一个位置的贡献拆成了两项可以预处理的求和之积,这样就
- 那么第一个挑战就来自于 9,它引入了奇偶性的考虑。我们对一个单独的
9 推式子:
我们发现
第二天起来就开始细化解法的同时写代码。然而写到 5 的时候我发现事情好像不那么简单,式子中涉及了同时与
总之,我们先根据上面的线性做法计算整个串符合假设时的答案,而对于短串暴力计算答案并减去可能错误计算的答案,整个题目就在
完整代码:
#include <bits/stdc++.h>
using namespace std;
using i64 = int64_t;
using u64 = uint64_t;
constexpr i64 MOD = 998244353, N = 5e4 + 500 + 10, B = 128 + 32;
i64 p10[N + 1]{}, ip10[N + 1]{}, m11[N + 1]{}, m1010[N + 1]{};
i64 sorted_arr[N + 1]{}, a[N + 1]{};
i64 inv10, inv100;
i64 p10_suf_cnt0[N + 1]{}, pre_p10_suf_cnt0[N + 1]{};
i64 p10_pre_cnt01[N + 1]{}, suf_p10_pre_cnt01[N + 1]{};
i64 p10_suf_cnt01[N + 1]{}, pre_p10_suf_cnt01[N + 1]{};
i64 p10_pre_cntall[N + 1]{}, suf_p10_pre_cntall[N + 1]{};
i64 suf_p10[N + 1]{};
i64 pi100_suf_par_cnt9[2][N + 1]{}, pre_par_pi100_suf_par_cnt9[2][2][N + 1]{};
i64 par_suf_p10[2][N + 1]{};
i64 suf_cnt8[N + 1]{}, p10_suf_cnt8[N + 1]{}, pre_p10_suf_cnt8[N + 1]{};
i64 pi100_suf_par_cnt5[2][N + 1]{}, pre_par_pi100_suf_par_cnt5[2][2][N + 1]{};
using CountType = array<array<int, 10>, 2>;
i64 modpow(i64 a, i64 b, i64 M) {
i64 res = 1;
while (b) {
if (b & 1) res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
i64 modinv(i64 a) { return modpow(a, MOD - 2, MOD); }
void pre_calc() {
fill(a, a + N + 1, 0);
inv10 = modinv(10);
p10[0] = ip10[0] = 1;
for (int i = 1; i <= N; i++) {
p10[i] = (p10[i - 1] * 10) % MOD;
ip10[i] = (ip10[i - 1] * inv10) % MOD;
}
m11[0] = 0;
for (int i = 1; i <= N; i++) {
// 11111... (i 1s)
m11[i] = (m11[i - 1] * 10 + 1) % MOD;
}
inv100 = modinv(100);
m1010[1] = 1;
for (int i = 1; i <= N; i++) {
// 1.01010101... (i 1s)
m1010[i] = (m1010[i - 1] * inv100 + 1) % MOD;
}
}
i64 brute_calc(const CountType &cnt, int n) {
fill(sorted_arr, sorted_arr + n, -1);
array<int, 2> rem = {(n + 1) / 2, n / 2};
int p0 = 0, p1 = 1, p = 0;
for (int i = 0; i < cnt[0][9]; i++) {
sorted_arr[p0] = 9;
p0 += 2;
rem[0]--;
}
for (int i = 0; i < cnt[1][9]; i++) {
sorted_arr[p1] = 9;
p1 += 2;
rem[1]--;
}
p = min(p0, p1);
for (int i = 0; i < cnt[0][8] + cnt[1][8]; i++) {
while (true) {
while (sorted_arr[p] != -1) p++;
bool feasible = rem[p % 2] > cnt[p % 2][5];
if (feasible) {
sorted_arr[p] = 8;
rem[p % 2]--;
p++;
break;
}
p++;
}
}
for (int i = 0; i < cnt[0][5]; i++) {
while (sorted_arr[p0] != -1) p0 += 2;
sorted_arr[p0] = 5;
p0 += 2;
}
for (int i = 0; i < cnt[1][5]; i++) {
while (sorted_arr[p1] != -1) p1 += 2;
sorted_arr[p1] = 5;
p1 += 2;
}
for (int i = 0; i < cnt[0][1] + cnt[1][1]; i++) {
while (sorted_arr[p] != -1) p++;
sorted_arr[p] = 1;
p++;
}
for (int i = 0; i < cnt[0][0] + cnt[1][0]; i++) {
while (sorted_arr[p] != -1) p++;
sorted_arr[p] = 0;
p++;
}
i64 res = 0;
for (int i = 0; i < n; i++) {
if (sorted_arr[i] == -1) {
// for (int i = 0; i < n; i++) {
// cout << sorted_arr[i] << ", ";
// }
// cout << endl;
// for (int par: {0, 1})
// for (int x: {0, 1, 5, 8, 9}) {
// cout << par << " " << x << " " << cnt[par][x] << endl;
// }
exit(0);
}
res += p10[n - i - 1] * sorted_arr[i];
}
return res % MOD;
}
// This is O(1) but with huge constant
i64 errornous_calc(const CountType &cnt, int l, int r) {
// first calc 8...81...10...0
i64 c8 = cnt[0][8] + cnt[1][8] + cnt[0][9] + cnt[1][9],
c1 = cnt[0][1] + cnt[1][1] + cnt[0][5] + cnt[1][5], c0 = cnt[0][0] + cnt[1][0];
i64 vl = 8 * m11[c8], vm = m11[c1];
i64 bm = p10[c1], br = p10[c0];
i64 res = (vl * bm + vm) % MOD * br % MOD;
// then add 9 and 5
res += p10[r - l] * m1010[cnt[0][9]] % MOD; // 1.0101... * 10^(r - l)
if (r > l) {
res += p10[r - l - 1] * m1010[cnt[1][9]] % MOD;
} else {
res += ip10[l + 1 - r] * m1010[cnt[1][9]] % MOD;
}
i64 l5 = l + c8;
res += p10[r - l5] * m1010[cnt[c8 % 2][5]] * 4 % MOD; // 5 - 1
if (r > l5) {
res += p10[r - l5 - 1] * m1010[cnt[(c8 + 1) % 2][5]] * 4 % MOD;
} else {
res += ip10[l5 + 1 - r] * m1010[cnt[(c8 + 1) % 2][5]] * 4 % MOD;
}
return res % MOD;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
pre_calc();
string s; cin >> s;
int n = s.size();
// mt19937_64 rng(11451419198);
// int n = 5000;
// string s(n, '6');
// for (int i = 0; i < n; i++) {
// s[i] = "19580"[rng() % 5];
// }
for (int i = 1; i <= n; i++) {
a[i] = s[i - 1] - '0';
}
// CountType cnt{};
// for (int i = 1; i <= n; i++) {
// cnt[i % 2][a[i]]++;
// }
// cout << errornous_calc(cnt, 1, n);
// return 0;
i64 ans = 0;
for (int l = 1; l <= n; l++) {
CountType cnt{};
i64 now_ans = 0;
for (int r = l; r <= l + B && r <= n; r++) {
cnt[(r - l) % 2][a[r]]++;
now_ans = brute_calc(cnt, r - l + 1) % MOD;
now_ans += (MOD - errornous_calc(cnt, l, r)) % MOD;
// cout << l << " " << r << " " << errornous_calc(cnt, l, r) << "\n";
ans = (ans + now_ans) % MOD;
}
ans %= MOD;
if (l + B >= n) continue;
}
// cout << (ans % MOD) << ": less than B\n";
// we need some precalc first
// for example, the contribution of one 1 is
// sum_l sum_r 10^ {cnt0_left + cnt01_right}
// = sum_l sum_r 10 ^ {cnt0_suf_l - cnt0_suf_i + cnt01_pre_r - cnt01_pre_i}
// = sum_l (10 ^ cnt0_suf_l) * sum_r (10 ^ cnt01_pre_r) * 10 ^ (- cnt0_suf_i - cnt01_pre_i)
p10_suf_cnt0[n + 1] = 1;
for (int i = n; i >= 1; i--) {
p10_suf_cnt0[i] = p10_suf_cnt0[i + 1];
if (a[i] == 0) (p10_suf_cnt0[i] *= 10) %= MOD;
}
for (int i = 1; i <= n; i++) {
(pre_p10_suf_cnt0[i] = pre_p10_suf_cnt0[i - 1] + p10_suf_cnt0[i]) %= MOD;
}
p10_pre_cnt01[0] = 1;
for (int i = 1; i <= n; i++) {
p10_pre_cnt01[i] = p10_pre_cnt01[i - 1];
if (a[i] == 0 || a[i] == 1 || a[i] == 5) (p10_pre_cnt01[i] *= 10) %= MOD;
}
for (int i = n; i >= 1; i--) {
(suf_p10_pre_cnt01[i] = suf_p10_pre_cnt01[i + 1] + p10_pre_cnt01[i]) %= MOD;
}
i64 ans1 = 0;
for (int i = 1; i <= n; i++) {
if (a[i] == 1 || a[i] == 5) {
ans1 += pre_p10_suf_cnt0[i] * suf_p10_pre_cnt01[i] % MOD *
modinv(p10_suf_cnt0[i]) % MOD * modinv(p10_pre_cnt01[i]) % MOD;
// cout << i << " " << p10_suf_cnt0[i] << " " << pre_p10_suf_cnt0[i] << endl;
// cout << i << "[ans1]" << pre_p10_suf_cnt0[i] * suf_p10_pre_cnt01[i] % MOD *
// modinv(p10_suf_cnt0[i]) % MOD * modinv(p10_pre_cnt01[i]) % MOD << endl;
ans1 %= MOD;
}
}
(ans += ans1) %= MOD;
// cout << ans1 << ": ans1" << endl;
// for 8 we have the same logic.
// sum_l (10 ^ cnt01_suf_l) * sum_r (10 ^ cnt_all_pre_r) * 10 ^ (- cnt01_suf_i - cnt_all_pre_i)
p10_suf_cnt01[n + 1] = 1;
for (int i = n; i >= 1; i--) {
p10_suf_cnt01[i] = p10_suf_cnt01[i + 1];
if (a[i] == 0 || a[i] == 1 || a[i] == 5) (p10_suf_cnt01[i] *= 10) %= MOD;
}
for (int i = 1; i <= n; i++) {
(pre_p10_suf_cnt01[i] = pre_p10_suf_cnt01[i - 1] + p10_suf_cnt01[i]) %= MOD;
}
p10_pre_cntall[0] = 1;
for (int i = 1; i <= n; i++) {
p10_pre_cntall[i] = p10_pre_cntall[i - 1];
(p10_pre_cntall[i] *= 10) %= MOD;
}
for (int i = n; i >= 1; i--) {
(suf_p10_pre_cntall[i] = suf_p10_pre_cntall[i + 1] + p10_pre_cntall[i]) %= MOD;
}
i64 ans8 = 0;
for (int i = 1; i <= n; i++) {
if (a[i] == 9 || a[i] == 8) {
ans8 += 8 * pre_p10_suf_cnt01[i] * suf_p10_pre_cntall[i] % MOD *
modinv(p10_suf_cnt01[i]) % MOD * modinv(p10_pre_cntall[i]) % MOD;
ans8 %= MOD;
}
}
// cout << ans8 << ": ans8" << endl;
(ans += ans8) %= MOD;
// for 9 it's harder...
// if 9 and l have the same parity:
// sum_l sum_r 10 ^ {r - l - 2 * cnt_left[i % 2][9])}
// = sum_l sum_r 10 ^ {r - l - 2 * suf_cnt9_[i % 2]_l + 2 * suf_cnt9_[i % 2]_i}
// = (sum_l_{same parity} 10 ^ {-l - 2 * suf_cnt9_[i % 2]_l}) * (sum_r 10^r) * 10 ^ {2 * suf_cnt9_[i % 2]_i}
// otherwise:
// (sum_l_{diff parity} 10 ^ {-l - 2 * suf_cnt9_[i % 2]_l}) * (sum_r 10^r) / 10 * 10 ^ {2 * suf_cnt9_[i % 2]_i}
for (int i = n; i >= 1; i--) {
suf_p10[i] = (suf_p10[i + 1] + p10[i]) % MOD;
}
for (int par: {0, 1}) {
pi100_suf_par_cnt9[par][n + 1] = 1;
for (int i = n; i >= 1; i--) {
pi100_suf_par_cnt9[par][i] = pi100_suf_par_cnt9[par][i + 1];
if (i % 2 == par && a[i] == 9) {
(pi100_suf_par_cnt9[par][i] *= inv100) %= MOD;
}
}
for (int sum_par: {0, 1}) {
for (int i = 1; i <= n; i++) {
pre_par_pi100_suf_par_cnt9[sum_par][par][i] = pre_par_pi100_suf_par_cnt9[sum_par][par][i - 1];
if (i % 2 == sum_par) {
(pre_par_pi100_suf_par_cnt9[sum_par][par][i] += pi100_suf_par_cnt9[par][i] * ip10[i] % MOD) %= MOD;
}
}
}
}
i64 ans9 = 0;
for (int i = 1; i <= n; i++) {
if (a[i] == 9) {
ans9 += pre_par_pi100_suf_par_cnt9[i % 2][i % 2][i] * suf_p10[i] % MOD
* modinv(pi100_suf_par_cnt9[i % 2][i]) % MOD;
ans9 += pre_par_pi100_suf_par_cnt9[(i % 2) ^ 1][i % 2][i] * suf_p10[i] % MOD * inv10 % MOD
* modinv(pi100_suf_par_cnt9[i % 2][i]) % MOD;
ans9 %= MOD;
// cout << i << " " << ans9 << "\n";
}
}
(ans += ans9) %= MOD;
// cout << ans9 << ": ans9" << endl;
// The most difficult part [panic]
// The contribution of 5
// if i and (l - cnt8) have the same parity:
// sum_l sum_r 10 ^ {r - l - cnt8 - 2 * cnt_left[i % 2][5])}
// = sum_l sum_r 10 ^ {r - l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l + 2 * suf_cnt5_[i % 2]_i + suf_cnt8_(r+1))}
// = (sum_l_{same parity} 10 ^ {-l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l}) * (sum_r_{same parity} 10^ {r + suf_cnt8_(r+1)}) * 100 ^ suf_cnt5_[i % 2]_i
// otherwise just / 10
// How to determine the parity? (i + suf_cnt_8[r + 1]) and (l + suf_cnt_8[l]) should have the same parity.
// It doesn't seem as hard as it looks.
suf_cnt8[n + 1] = 0; p10_suf_cnt8[n + 1] = 1;
for (int i = n; i >= 1; i--) {
suf_cnt8[i] = suf_cnt8[i + 1]; p10_suf_cnt8[i] = p10_suf_cnt8[i + 1];
if (a[i] == 9 || a[i] == 8) {
suf_cnt8[i] += 1;
(p10_suf_cnt8[i] *= 10) %= MOD;
}
}
pre_p10_suf_cnt8[0] = 0;
for (int i = 1; i <= n; i++) {
(pre_p10_suf_cnt8[i] = pre_p10_suf_cnt8[i - 1] + p10_suf_cnt8[i]) %= MOD;
}
for (int par: {0, 1}) {
for (int i = n; i >= 1; i--) {
par_suf_p10[par][i] = par_suf_p10[par][i + 1];
if (suf_cnt8[(i + 1)] % 2 == par) {
// sum_r_{same parity} 10 ^ {r + suf_cnt8_(r+1)}
(par_suf_p10[par][i] += p10[i] * p10_suf_cnt8[i + 1] % MOD) %= MOD;
}
}
}
for (int par: {0, 1}) {
pi100_suf_par_cnt5[par][n + 1] = 1;
for (int i = n; i >= 1; i--) {
pi100_suf_par_cnt5[par][i] = pi100_suf_par_cnt5[par][i + 1];
if (i % 2 == par && a[i] == 5) {
(pi100_suf_par_cnt5[par][i] *= inv100) %= MOD;
}
}
for (int sum_par: {0, 1}) {
for (int i = 1; i <= n; i++) {
pre_par_pi100_suf_par_cnt5[sum_par][par][i] = pre_par_pi100_suf_par_cnt5[sum_par][par][i - 1];
if ((i + suf_cnt8[i]) % 2 == sum_par) {
// 10 ^ {-l - suf_cnt8_l - 2 * suf_cnt5_[i % 2]_l})
(pre_par_pi100_suf_par_cnt5[sum_par][par][i] += pi100_suf_par_cnt5[par][i] * ip10[i] % MOD
* modinv(p10_suf_cnt8[i]) % MOD) %= MOD;
}
}
}
}
i64 ans5 = 0;
for (int i = 1; i <= n; i++) {
if (a[i] == 5) {
i64 lans5 = ans5;
int ipar = i % 2;
for (int lpar: {0, 1}) for (int rpar: {0, 1}) {
if ((ipar + lpar + rpar) % 2 == 0) {
ans5 += 4 * pre_par_pi100_suf_par_cnt5[lpar][ipar][i] * par_suf_p10[rpar][i] % MOD
* modinv(pi100_suf_par_cnt5[ipar][i]) % MOD;
} else {
ans5 += 4 * pre_par_pi100_suf_par_cnt5[lpar][ipar][i] * par_suf_p10[rpar][i] % MOD * inv10 % MOD
* modinv(pi100_suf_par_cnt5[ipar][i]) % MOD;
}
ans5 %= MOD;
}
// cout << i << "[10 diff ans5]" << (MOD + ans5 - lans5) * 10 % MOD << endl;
}
}
// cout << ans5 << ": ans5" << endl;
(ans += ans5) %= MOD;
cout << ans << "\n";
// i64 ans_brute = 0;
// for (int l = 1; l <= n; l++) {
// CountType cnt{};
// for (int r = l; r <= n; r++) {
// cnt[(r - l) % 2][a[r]]++;
// (ans_brute += brute_calc(cnt, r - l + 1)) %= MOD;
// }
// }
// cout << ans_brute << "\n";
}