P9171 [省选联考 2023] 染色数组

· · 题解

P9171 [省选联考 2023] 染色数组

很明显,要想会做第二问,首先得会做第一问。要想会做第一问,首先得找到方便统计答案的方式刻画完美数组。

刻画完美数组

从定义入手,完美数组是具有至少两种染色方案的数组,显然存在一个元素在两种方案中分别被染上了红色和绿色。这启发我们考虑一个元素可以被染成红色和绿色的充要条件。

将四个条件拼起来(其中两个条件分别弱于另外两个条件),得

结论 1:元素 a_i 可以被染成红色和绿色,当且仅当其满足以下所有条件:

这两个条件看起来非常棒,满足条件的元素性质一定很好。称这些元素为 关键元素,对应下标集合 S

接下来探究关键元素的性质。

考虑两个关键元素 a_l, a_rl < r):

结论 2:完美数组的 S 为一段 非空 区间 [l, r],且满足下列条件之一:

关键元素的形态刻画完了,再考虑非关键元素的形态,易知

结论 3:设 L = \min(a_l, a_r)R = \max(a_l, a_r),则

Q(r, x, y) 表示条件:[1, r) 中小于 x 的数递增,大于 y 的数递减,且不存在值落在 [x, y] 的元素。

第一问

我们已经得到充分的性质解决第一问。考虑在 a_r 处统计答案。

综合上述讨论,发现要求满足 Q(r, x, y) 的方案数。对于后缀,限制是类似的。

第二问

接下来考虑第二问。

计数题里面套了个 \max,怎么看都觉得阴间。

在考虑每个元素产生的贡献时,其前面的元素的颜色是不重要的。因此,考虑一个完美数组,计算其得分时,只有关键元素的颜色会对 \max 产生影响。

结论 4:完美数组的得分为:

  • 对于情况 1,任意一种染色方案的得分。
  • 对于情况 2 和 3,将 [1, r)< a_r 的元素和 (r, n]> a_r 的元素染成红色,[1, r)> a_r 的元素和 (r, n]< a_r 的元素染成绿色,a_r 染成贡献较大的颜色的染色方案的得分。

这两种情况可以统一起来,但计算时仍需分类讨论。原因是情况 1 和情况 2,3 本质不同。

以下讨论情况 2 和 3

枚举 $r$,$a_r$ 和 $a_{r + 1}$。**不妨设 $\boldsymbol {a_r < a_{r + 1}}$,反之同理**。 - 对于 $[1, r)$:类似第一问 DP。维护所有完美数组的当前得分之和,红色元素数量之和,绿色元素数量之和,以及方案数。时间复杂度 $\mathcal{O}(nm ^ 2)$。 - 对于 $r$:若从 DP 入手,则需再记录一维表示红色元素的数量,有点麻烦。 考虑组合方法。注意首先检查 $[1, r)$ 已经确定的部分是否符合要求。 设 $[1, r)$ 已经确定的部分 $< a_r$ 的元素数量为 $c_d$,最大值为 $v_d$,$> a_r$ 的元素数量为 $c_u$,最小值为 $v_u$。设还有 $k$ 个位置没有确定。枚举未确定部分 $< a_r$ 的元素数量 $i$,则贡献为 $$ \sum_{i = 0} ^ k \binom {k} {i} \binom {a_r - v_d - 1} {i} \binom {v_u - a_r - 1} {k - i} \max((m - a_r + 1) (c_u + k - i), a_r (c_d + i)) $$ 再减去不合法(使得 $a_{r + 1}\in S$)的贡献,对应第一问的容斥。 $$ \sum_{i = 0} ^ k \binom k i \binom {a_r - v_d - 1} {i} \binom {v_u - a_{r + 1} - 1} {k - i} \max((m - a_r + 1) (c_u + k - i), a_r(c_d + i)) $$ 不要忘记乘以对应方案数。时间复杂度 $\mathcal{O}(n ^ 2m ^ 2)$。 - 对于 $(r, n]$:根据点对贡献的定义,位置在 $r$ 之后的元素只会和位置在 $r$ 之前的元素产生贡献。这两部分相对独立。 计算每个数在 $[1, r)$ 中出现的方案数(只安排 $a_{1\sim r - 1}$,注意 $a_r$ 和 $a_{r + 1}$ 已确定)$c_i$。枚举产生贡献的数 $a_j$(只关心数值,不关心位置),根据 $a_j$ 和 $a_r$ 的相对大小关系可知其颜色。不妨设 $a_j$ 被染成绿色($a_r > a_j$),则 $a_j$ 产生的贡献为:$a_j$ 在 $(r, n]$ 中出现的方案数(只安排 $a_{r + 1\sim n}$)$d_{a_j}$,乘以 $a_j$,再乘以 $\sum_{i = 1} ^ {a_j - 1} c_i$。 计算 $c_i$: - 若 $i$ 在 $[1, r)$ 已经确定确定的部分出现,则 $c_i$ 等于安排 $a_{1\sim r - 1}$ 的方案数。 - 否则,对于 $i < a_r$,要求 $v_d < i < a_r$,可以认为 $i$ 在所有方案中是均匀的(感性理解),即所有 $v_d < i < a_r$ 的 $c_i$ 相等。算出 $\sum c_i$ 即 $$ \sum_{i = 0} ^ k i\binom k i \binom {a_r - v_d - 1} {i} \left(\binom {v_u - a_{r} - 1} {k - i} - \binom {v_u - a_{r + 1} - 1} {k - i} \right) $$ 对于 $i > a_r$,分别对容斥的两部分算出平均值再相减,而不是先相减再取平均。 计算 $d_{a_j}$: - 类似地,若 $a_j$ 在 $(r, n]$ 已经确定的部分出现或等于枚举的 $a_{r + 1}$,则 $d_{a_j}$ 等于安排 $a_{r + 2\sim n}$ 的方案数。 - 否则直接对 $< x$ 和 $> y$ 的两部分分别算总数取平均(后缀不需要容斥)。用组合数算总数。 注意这里不是 $< a_r$ 和 $> a_{r + 1}$,因为当 $r < t$ 时,已经确定的数对未确定的数有更紧的界。 直接做的复杂度是 $\mathcal{O}(nm ^ 3)$,因为枚举了 $a_j$。考虑优化复杂度。 ##### 优化复杂度 当 $r < t$ 时,$a_r$ 和 $a_{r + 1}$ 均不需要枚举。当 $r\geq t$ 时,$[1, a_r)$,$a_{r + 1}$ 和 $(a_{r + 1}, n]$ 的 $d$ 值分别相同,枚举太浪费时间。 - 当 $r < t$ 时,可直接暴力计算。 - 当 $r\geq t$ 时,转换视角,考虑 $c_i$ 和所有 $d_j$ 之间的贡献,后者形如 $j$ 属于某个区间范围内的 $\sum d_j j$ 或 $\sum d_j (m - j + 1)$。因 $d$ 值为三个连续段,可以 $\mathcal{O}(1)$ 计算。 在 $a_{1\sim t}$ 中出现过的 $i$(共 $t$ 个)的 $c_i$ 的贡献分别 $\mathcal{O}(1)$ 计算,而没出现过的 $c_i$ 依然为取值相同的三段区间,$\mathcal{O}(1)$ 计算一段 $c$ 值相同的区间和一段 $d$ 值相同的区间之间的贡献。 总时间复杂度 $\mathcal{O}(n ^ 2 m ^ 2)$。 一些细节: - 在使用组合方法计算前缀或后缀方案数时,先判断已经确定的数是否合法。 - 特判 $r = n$。 - 情况 1 的第二问和情况 2,3 的第二问略有不同,可类似推导,细节处理一定要仔细,不能想当然。 - 可以对 $i < t$ 预处理 $r = i$ 时后缀是否合法及对应方案数,方便处理。 - 题解的大部分地方都只讨论了一种情况,因为另一种情况是对称的。繁琐的分类讨论导致代码量和细节很多,想清楚再开始写。 ```cpp #include <bits/stdc++.h> using namespace std; using ll = long long; bool Mbe; constexpr int N = 50 + 5, M = 200 + 5, mod = 998244353; void cmin(int &x, int y) {x = x < y ? x : y;} void cmax(int &x, int y) {x = x > y ? x : y;} void addt(int &x, int y) {x += y, x >= mod && (x -= mod);} void subf(int &x, int y) {x -= y, x < 0 && (x += mod);} int add(int x, int y) {return x += y, x >= mod && (x -= mod), x;} int sub(int x, int y) {return x -= y, x < 0 && (x += mod), x;} int S(int l, int r) {return (l + r) * (r - l + 1) / 2;} int S2(int x) {return x * (x + 1) * (x + x + 1) / 6;} int S2(int l, int r) {return S2(r) - S2(l - 1);} int n, m, t, ans1, ans2, a[N], b[N], d[M], up[N], dn[N], inv[M], val1[N], val2[N], C[M][M], g[N][M][M], gi[N][M][M]; bool chk(int i, int v) {return !a[i] || a[i] == v;} int getg(int k, int x, int y) {return x < 0 || y < 0 ? 0 : x < y ? g[k][x][y] : g[k][y][x];} int getgi(int k, int x, int y) {return x < 0 || y < 0 ? 0 : 1ll * (x < y ? gi[k][x][y] : sub(1ll * g[k][y][x] * k % mod, gi[k][y][x])) * inv[x] % mod;} // 想一想 x > y 时为什么要将 i 变成 k - i. int csuf(int p, int x, int y) {return p - 2 < t ? max(0, b[p - 2]) : getg(n - (p - 1), x - 1, m - y);} struct dat { int f, su, lo, hi; dat operator + (const dat &z) const {return {add(f, z.f), add(su, z.su), add(lo, z.lo), add(hi, z.hi)};} dat operator - (const dat &z) const {return {sub(f, z.f), sub(su, z.su), sub(lo, z.lo), sub(hi, z.hi)};} dat addlo(int z) {return {f, add(su, 1ll * (m - z + 1) * hi % mod), add(lo, f), hi};} dat addhi(int z) {return {f, add(su, 1ll * z * lo % mod), lo, add(hi, f)};} } f[N][M][M], E; struct itv {int l, r, val;}; void solve() { cin >> n >> m >> t, ans1 = ans2 = 0; memset(a, 0, sizeof(a)); for(int i = 1; i <= t; i++) cin >> a[i]; for(int i = 1; i < t; i++) { dn[i] = min(a[i], a[i + 1]), up[i] = max(a[i], a[i + 1]); for(int j = i + 2; j <= t; j++) a[j] > up[i] ? up[i] = a[j] : dn[i] = (a[j] < dn[i] ? a[j] : -1); b[i] = dn[i] != -1 ? getg(n - t, dn[i] - 1, m - up[i]) : -1; // -1 表示 r 不合法. } memset(f, 0, sizeof(f)); for(int x = 0; x <= m; x++) for(int y = m + 1; y; y--) f[0][x][y] = {1, 0, 0, 0}; for(int i = 1; i < n; i++) { // DP 部分. for(int v = 1; v <= m; v++) if(chk(i, v)) { for(int y = m + 1; y > v; y--) f[i][v][y] = (f[i - 1][v - 1][y] - f[i - 1][v - 1][y + 1]).addlo(v); for(int x = 0; x <= m; x++) f[i][x][v] = f[i][x][v] + (f[i - 1][x][v + 1] - (x ? f[i - 1][x - 1][v + 1] : E)).addhi(v); } for(int x = 0; x <= m; x++) for(int y = m + 1; y > x; y--) // 二维前缀和. f[i][x][y] = f[i][x][y] + f[i][x][y + 1] + (x ? f[i][x - 1][y] - f[i][x - 1][y + 1] : E); } for(int r = 1; r <= n; r++) for(int v = 1; v <= m; v++) if(chk(r, v)) { int k = max(0, r - 1 - t), cd = 0, vd = 0, cu = 0, vu = m + 1, equal = 0; int pre = 0, suf = 0, tot = 0, sufok = r >= t || r < t && b[r] != -1; for(int i = 1; i < r && i <= t && vd != -1 && vu != -1; i++) { if(a[i] < v) cd++, a[i] > vd ? vd = a[i] : vd = -1; if(a[i] == v) equal++; if(a[i] > v) cu++, a[i] < vu ? vu = a[i] : vu = -1; } if(vu == -1 || vd == -1 || equal > 1 || equal == 1 && a[r - 1] != v) continue; // 前缀不合法, 注意情况 1 的 equal 可以等于 1. auto calc = [&](int v2) { // (r, n], 被调用 O(nm ^ 2) 次. if(r == 1 || r == n) return; // r = 1 或 r = n 一定不产生贡献. if(r < t) { // 总共调用 O(n) 次, 随便暴力. memset(d, 0, sizeof(d)); if(v < v2 && vu > v2 || v > v2 && vd < v2) return; // 判 r + 1 \in S 不合法. int hi = getgi(n - t, m - up[r], dn[r] - 1), lo = getgi(n - t, dn[r] - 1, m - up[r]); for(int i = up[r] + 1; i <= m; i++) d[i] = 1ll * hi * (m - i + 1) % mod; for(int i = dn[r] - 1; i >= 1; i--) d[i] = 1ll * lo * i % mod; for(int i = r + 1; i <= t; i++) addt(d[a[i]], 1ll * suf * (a[i] < v ? a[i] : m - a[i] + 1) % mod); for(int i = max(v, v2) + 1; i <= m; i++) addt(d[i], d[i - 1]); for(int i = min(v, v2) - 1; i; i--) addt(d[i], d[i + 1]); for(int i = 1; i < r - (v == v2); i++) addt(ans2, a[i] < v ? d[a[i] + 1] : d[a[i] - 1]); // 已经确定的数的 c = 1. } else { static itv c[3], d[3]; int C = 0, D = 0; if(v != v2) { d[D++] = {1, min(v, v2) - 1, getgi(n - (r + 1), min(v, v2) - 1, m - max(v, v2))}; d[D++] = {max(v, v2) + 1, m, getgi(n - (r + 1), m - max(v, v2), min(v, v2) - 1)}; d[D++] = {v2, v2, suf}; if(v < v2) { if(vd + 1 < v) c[C++] = {vd + 1, v - 1, sub(getgi(k, v - vd - 1, vu - v - 1), getgi(k, v - vd - 1, vu - v2 - 1))}; if(v2 + 1 < vu) c[C++] = {v2 + 1, vu - 1, sub(getgi(k, vu - v - 1, v - vd - 1), getgi(k, vu - v2 - 1, v - vd - 1))}; } else { if(v + 1 < vu) c[C++] = {v + 1, vu - 1, sub(getgi(k, vu - v - 1, v - vd - 1), getgi(k, vu - v - 1, v2 - vd - 1))}; if(vd + 1 < v2) c[C++] = {vd + 1, v2 - 1, sub(getgi(k, v - vd - 1, vu - v - 1), getgi(k, v2 - vd - 1, vu - v - 1))}; } } else { d[D++] = {1, v - 1, getgi(n - r, v - 1, m - v)}; d[D++] = {v + 1, m, getgi(n - r, m - v, v - 1)}; if(vd + 1 < v && k) c[C++] = {vd + 1, v - 1, getgi(k - 1, v - vd - 1, vu - v - 1)}; // 这里是 k - 1! if(v + 1 < vu && k) c[C++] = {v + 1, vu - 1, getgi(k - 1, vu - v - 1, v - vd - 1)}; } int tot = 0; for(int i = 1; i < r - (v == v2) && i <= t; i++) for(int j = 0; j < D; j++) { // 计算已经确定的 c_i 的贡献. const itv &I = d[j]; if(a[i] < v && I.l < v && I.r > a[i]) addt(tot, 1ll * S(max(I.l, a[i] + 1), I.r) * I.val % mod); if(a[i] > v && I.l > v && I.l < a[i]) addt(tot, 1ll * S(m - min(I.r, a[i] - 1) + 1, m - I.l + 1) * I.val % mod); } addt(ans2, 1ll * tot * pre % mod); for(int i = 0; i < C; i++) for(int j = 0; j < D; j++) { const itv &I = c[i], &J = d[j]; int coef = 0; if(I.l < v && J.l < v) { // c 和 d 的区间, 要么后者包含前者, 要么后者为 [v2, v2] if(J.l <= I.l && I.r <= J.r) coef += (I.r - I.l + 1) * J.r * (J.r + 1) - S2(I.l, I.r) - S(I.l, I.r) >> 1; else if(I.r < J.l) coef += (I.r - I.l + 1) * J.l; // J.l = J.r = v2 } if(I.l > v && J.l > v) { if(J.l <= I.l && I.r <= J.r) coef += S(I.l - J.l, I.r - J.l) * (m + 1) - (S2(I.l, I.r) - S(I.l, I.r) - (I.r - I.l + 1) * (J.l - 1) * J.l >> 1); else if(J.r < I.l) coef += (I.r - I.l + 1) * (m - J.l + 1); } addt(ans2, 1ll * coef * I.val % mod * J.val % mod); } } }; if(r > 1 && chk(r - 1, v) && sufok) { suf = csuf(r + 1, v, v); addt(ans1, 1ll * (pre = f[r - 2][v - 1][v + 1].f) * suf % mod); addt(ans2, 1ll * f[r - 2][v - 1][v + 1].addlo(v).su * suf % mod); // [1, r) addt(ans2, 1ll * f[r - 2][v - 1][v + 1].lo * v % mod * suf % mod); // r calc(v); } if(equal) continue; // 情况 2 和 3 要求前缀没有等于 a[r] 的元素. for(int i = 0, val; i <= k; i++) { // 预处理经常要用的式子, 减小常数. val = 1ll * C[k][i] * max((m - v + 1) * (cu + k - i), v * (cd + i)) % mod; val1[i] = 1ll * val * C[v - vd - 1][i] % mod, val2[i] = 1ll * val * C[vu - v - 1][k - i] % mod; } if(r == n) { addt(ans1, f[r - 1][v - 1][v + 1].f); addt(ans2, f[r - 1][v - 1][v + 1].su); // [1, r) for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r addt(ans2, 1ll * val1[i] * C[vu - v - 1][k - i] % mod); } else if(sufok) { for(int v2 = v + 1; v2 <= m; v2++) if(chk(r + 1, v2)) { suf = csuf(r + 2, v, v2), tot = 0; addt(ans1, 1ll * (pre = (f[r - 1][v - 1][v + 1] - f[r - 1][v - 1][v2 + 1]).f) * suf % mod); addt(tot, (f[r - 1][v - 1][v + 1] - f[r - 1][v - 1][v2 + 1]).su); // [1, r) for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r addt(tot, 1ll * val1[i] * sub(C[vu - v - 1][k - i], k - i < vu - v2 ? C[vu - v2 - 1][k - i] : 0) % mod); calc(v2), addt(ans2, 1ll * tot * suf % mod); } for(int v2 = 1; v2 < v; v2++) if(chk(r + 1, v2)) { suf = csuf(r + 2, v2, v), tot = 0; addt(ans1, 1ll * (pre = (f[r - 1][v - 1][v + 1] - f[r - 1][v2 - 1][v + 1]).f) * suf % mod); addt(tot, (f[r - 1][v - 1][v + 1] - f[r - 1][v2 - 1][v + 1]).su); // [1, r) for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r addt(tot, 1ll * val2[i] * sub(C[v - vd - 1][i], i < v2 - vd ? C[v2 - vd - 1][i] : 0) % mod); calc(v2), addt(ans2, 1ll * tot * suf % mod); } } } cout << ans1 << " " << ans2 << "\n"; } bool Med; int main() { fprintf(stderr, "%.4lf\n", (&Mbe - &Med) / 1048576.0); #ifdef ALEX_WEI FILE* IN = freopen("color.in", "r", stdin); FILE* OUT = freopen("color.out", "w", stdout); #endif ios::sync_with_stdio(0), cin.tie(0), cout.tie(0); inv[1] = 1; for(int i = 2; i < M; i++) inv[i] = mod - 1ll * mod / i * inv[mod % i] % mod; for(int i = 0; i < M; i++) for(int j = 0; j <= i; j++) C[i][j] = j ? add(C[i - 1][j - 1], C[i - 1][j]) : 1; for(int k = 0; k < N; k++) for(int x = 0; x < M; x++) for(int y = x; y + x < M; y++) for(int i = max(0, k - y), val; i <= min(x, k); i++) addt(g[k][x][y], val = 1ll * C[k][i] * C[x][i] % mod * C[y][k - i] % mod), addt(gi[k][x][y], 1ll * i * val % mod); int T; cin >> T; while(T--) solve(); cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n"; return 0; } /* g++ color.cpp -o color -std=c++14 -O2 -DALEX_WEI v = v2 的 k -> k - 1. 加入判断 equal = 1 && a[r - 1] != v. 将 S(I.l - J.l, I.r - J.l) * m 改成 S(I.l - J.l, I.r - J.l) * (m + 1). */ ```