P9171 [省选联考 2023] 染色数组
Alex_Wei
·
·
题解
P9171 [省选联考 2023] 染色数组
很明显,要想会做第二问,首先得会做第一问。要想会做第一问,首先得找到方便统计答案的方式刻画完美数组。
刻画完美数组
从定义入手,完美数组是具有至少两种染色方案的数组,显然存在一个元素在两种方案中分别被染上了红色和绿色。这启发我们考虑一个元素可以被染成红色和绿色的充要条件。
- 元素 a_i 被染成红色,那么
- 元素 a_i 被染成绿色,那么
将四个条件拼起来(其中两个条件分别弱于另外两个条件),得
结论 1:元素 a_i 可以被染成红色和绿色,当且仅当其满足以下所有条件:
这两个条件看起来非常棒,满足条件的元素性质一定很好。称这些元素为 关键元素,对应下标集合 S。
接下来探究关键元素的性质。
考虑两个关键元素 a_l, a_r(l < r):
- 如果 a_l = a_r,那么它们只能相邻,即 l + 1 = r,否则易证不合法。
- 如果 a_l > a_r,那么根据 a_l 和 a_r 是关键元素,可知 a_{l + 1\sim r - 1} 只能落在 [a_r + 1, a_l - 1],则 a_l > a_{l + 1} > \cdots > a_r。进一步推出 \forall i\in [l, r], i\in S。
- 如果 a_l < a_r,同理有 a_l < a_{l + 1} < \cdots < a_r,且 \forall i\in [l, r], i\in S。
结论 2:完美数组的 S 为一段 非空 区间 [l, r],且满足下列条件之一:
关键元素的形态刻画完了,再考虑非关键元素的形态,易知
结论 3:设 L = \min(a_l, a_r),R = \max(a_l, a_r),则
设 Q(r, x, y) 表示条件:[1, r) 中小于 x 的数递增,大于 y 的数递减,且不存在值落在 [x, y] 的元素。
第一问
我们已经得到充分的性质解决第一问。考虑在 a_r 处统计答案。
-
对于结论 2 的情况 1,枚举 r 和 a_r,根据结论 3 求出安排对应前缀 [1, r - 1) 和后缀 (r, n] 的方案数。
-
对于结论 2 的情况 2 和 3,枚举 r 和 a_r,则为了让 r + 1\notin S:
- 若 a_r < a_{r + 1},则要求 [1, r) 中 > a_r 的最小值不大于 a_{r + 1},否则无论是情况 2 还是情况 3,根据 a_r 是关键元素均可推出 a_{r + 1} 是关键元素,反之一定有 r + 1\notin S。
直接算不好算,容斥一下就是用 Q(r, a_r, a_r) 的方案数减去 Q(r, a_r, a_{r + 1}) 的方案数。这里及下文的方案数,均指不考虑 a_{r + 2\sim n},只考虑 a_{1\sim r - 1} 的方案数。
这说明我们还要枚举 a_{r + 1}。
综合上述讨论,发现要求满足 Q(r, x, y) 的方案数。对于后缀,限制是类似的。
- 设 f_{i, x, y} 表示 [1, i] 中满足 Q(i, x + 1, y - 1) 且存在 x 和 y 的方案数。初始递增序列开头为 0,递减序列开头为 m + 1,即初始化 f_{0, 0, m + 1} = 1。枚举 a_{i + 1} 转移结合前 / 后缀和优化做到 \mathcal{O}(nm ^ 2) 计算 f,对每个 f_i 做二维前缀和。
- 也可以用组合方法计算 f,接下来会讲,且非常重要。这样复杂度是 \mathcal{O}(n ^ 2 m ^ 2)。
第二问
接下来考虑第二问。
计数题里面套了个 \max,怎么看都觉得阴间。
在考虑每个元素产生的贡献时,其前面的元素的颜色是不重要的。因此,考虑一个完美数组,计算其得分时,只有关键元素的颜色会对 \max 产生影响。
-
对于情况 1,两个关键元素一红一绿,因值相等故两种染色方案得分相等。
-
对于情况 2,只有至多一个关键元素被染成绿色。考虑被染成绿色的关键元素,若其不为 a_r,则将其染成红色,下一个元素染成绿色,染色方案的得分不降。
因此,a_l\sim a_{r - 1} 被染成红色,只有 a_r 的颜色对 \max 产生影响。而 a_r 是要枚举的,所以它的贡献可在枚举时一并刻画,这看起来很棒。
-
对于情况 3,同理。
结论 4:完美数组的得分为:
- 对于情况 1,任意一种染色方案的得分。
- 对于情况 2 和 3,将 [1, r) 中 < a_r 的元素和 (r, n] 中 > a_r 的元素染成红色,[1, r) 中 > a_r 的元素和 (r, n] 中 < a_r 的元素染成绿色,a_r 染成贡献较大的颜色的染色方案的得分。
这两种情况可以统一起来,但计算时仍需分类讨论。原因是情况 1 和情况 2,3 本质不同。
以下讨论情况 2 和 3。
枚举 $r$,$a_r$ 和 $a_{r + 1}$。**不妨设 $\boldsymbol {a_r < a_{r + 1}}$,反之同理**。
- 对于 $[1, r)$:类似第一问 DP。维护所有完美数组的当前得分之和,红色元素数量之和,绿色元素数量之和,以及方案数。时间复杂度 $\mathcal{O}(nm ^ 2)$。
- 对于 $r$:若从 DP 入手,则需再记录一维表示红色元素的数量,有点麻烦。
考虑组合方法。注意首先检查 $[1, r)$ 已经确定的部分是否符合要求。
设 $[1, r)$ 已经确定的部分 $< a_r$ 的元素数量为 $c_d$,最大值为 $v_d$,$> a_r$ 的元素数量为 $c_u$,最小值为 $v_u$。设还有 $k$ 个位置没有确定。枚举未确定部分 $< a_r$ 的元素数量 $i$,则贡献为
$$
\sum_{i = 0} ^ k \binom {k} {i} \binom {a_r - v_d - 1} {i} \binom {v_u - a_r - 1} {k - i} \max((m - a_r + 1) (c_u + k - i), a_r (c_d + i))
$$
再减去不合法(使得 $a_{r + 1}\in S$)的贡献,对应第一问的容斥。
$$
\sum_{i = 0} ^ k \binom k i \binom {a_r - v_d - 1} {i} \binom {v_u - a_{r + 1} - 1} {k - i} \max((m - a_r + 1) (c_u + k - i), a_r(c_d + i))
$$
不要忘记乘以对应方案数。时间复杂度 $\mathcal{O}(n ^ 2m ^ 2)$。
- 对于 $(r, n]$:根据点对贡献的定义,位置在 $r$ 之后的元素只会和位置在 $r$ 之前的元素产生贡献。这两部分相对独立。
计算每个数在 $[1, r)$ 中出现的方案数(只安排 $a_{1\sim r - 1}$,注意 $a_r$ 和 $a_{r + 1}$ 已确定)$c_i$。枚举产生贡献的数 $a_j$(只关心数值,不关心位置),根据 $a_j$ 和 $a_r$ 的相对大小关系可知其颜色。不妨设 $a_j$ 被染成绿色($a_r > a_j$),则 $a_j$ 产生的贡献为:$a_j$ 在 $(r, n]$ 中出现的方案数(只安排 $a_{r + 1\sim n}$)$d_{a_j}$,乘以 $a_j$,再乘以 $\sum_{i = 1} ^ {a_j - 1} c_i$。
计算 $c_i$:
- 若 $i$ 在 $[1, r)$ 已经确定确定的部分出现,则 $c_i$ 等于安排 $a_{1\sim r - 1}$ 的方案数。
- 否则,对于 $i < a_r$,要求 $v_d < i < a_r$,可以认为 $i$ 在所有方案中是均匀的(感性理解),即所有 $v_d < i < a_r$ 的 $c_i$ 相等。算出 $\sum c_i$ 即
$$
\sum_{i = 0} ^ k i\binom k i \binom {a_r - v_d - 1} {i} \left(\binom {v_u - a_{r} - 1} {k - i} - \binom {v_u - a_{r + 1} - 1} {k - i} \right)
$$
对于 $i > a_r$,分别对容斥的两部分算出平均值再相减,而不是先相减再取平均。
计算 $d_{a_j}$:
- 类似地,若 $a_j$ 在 $(r, n]$ 已经确定的部分出现或等于枚举的 $a_{r + 1}$,则 $d_{a_j}$ 等于安排 $a_{r + 2\sim n}$ 的方案数。
- 否则直接对 $< x$ 和 $> y$ 的两部分分别算总数取平均(后缀不需要容斥)。用组合数算总数。
注意这里不是 $< a_r$ 和 $> a_{r + 1}$,因为当 $r < t$ 时,已经确定的数对未确定的数有更紧的界。
直接做的复杂度是 $\mathcal{O}(nm ^ 3)$,因为枚举了 $a_j$。考虑优化复杂度。
##### 优化复杂度
当 $r < t$ 时,$a_r$ 和 $a_{r + 1}$ 均不需要枚举。当 $r\geq t$ 时,$[1, a_r)$,$a_{r + 1}$ 和 $(a_{r + 1}, n]$ 的 $d$ 值分别相同,枚举太浪费时间。
- 当 $r < t$ 时,可直接暴力计算。
- 当 $r\geq t$ 时,转换视角,考虑 $c_i$ 和所有 $d_j$ 之间的贡献,后者形如 $j$ 属于某个区间范围内的 $\sum d_j j$ 或 $\sum d_j (m - j + 1)$。因 $d$ 值为三个连续段,可以 $\mathcal{O}(1)$ 计算。
在 $a_{1\sim t}$ 中出现过的 $i$(共 $t$ 个)的 $c_i$ 的贡献分别 $\mathcal{O}(1)$ 计算,而没出现过的 $c_i$ 依然为取值相同的三段区间,$\mathcal{O}(1)$ 计算一段 $c$ 值相同的区间和一段 $d$ 值相同的区间之间的贡献。
总时间复杂度 $\mathcal{O}(n ^ 2 m ^ 2)$。
一些细节:
- 在使用组合方法计算前缀或后缀方案数时,先判断已经确定的数是否合法。
- 特判 $r = n$。
- 情况 1 的第二问和情况 2,3 的第二问略有不同,可类似推导,细节处理一定要仔细,不能想当然。
- 可以对 $i < t$ 预处理 $r = i$ 时后缀是否合法及对应方案数,方便处理。
- 题解的大部分地方都只讨论了一种情况,因为另一种情况是对称的。繁琐的分类讨论导致代码量和细节很多,想清楚再开始写。
```cpp
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
bool Mbe;
constexpr int N = 50 + 5, M = 200 + 5, mod = 998244353;
void cmin(int &x, int y) {x = x < y ? x : y;}
void cmax(int &x, int y) {x = x > y ? x : y;}
void addt(int &x, int y) {x += y, x >= mod && (x -= mod);}
void subf(int &x, int y) {x -= y, x < 0 && (x += mod);}
int add(int x, int y) {return x += y, x >= mod && (x -= mod), x;}
int sub(int x, int y) {return x -= y, x < 0 && (x += mod), x;}
int S(int l, int r) {return (l + r) * (r - l + 1) / 2;}
int S2(int x) {return x * (x + 1) * (x + x + 1) / 6;}
int S2(int l, int r) {return S2(r) - S2(l - 1);}
int n, m, t, ans1, ans2, a[N], b[N], d[M], up[N], dn[N], inv[M], val1[N], val2[N], C[M][M], g[N][M][M], gi[N][M][M];
bool chk(int i, int v) {return !a[i] || a[i] == v;}
int getg(int k, int x, int y) {return x < 0 || y < 0 ? 0 : x < y ? g[k][x][y] : g[k][y][x];}
int getgi(int k, int x, int y) {return x < 0 || y < 0 ? 0 : 1ll * (x < y ? gi[k][x][y] : sub(1ll * g[k][y][x] * k % mod, gi[k][y][x])) * inv[x] % mod;} // 想一想 x > y 时为什么要将 i 变成 k - i.
int csuf(int p, int x, int y) {return p - 2 < t ? max(0, b[p - 2]) : getg(n - (p - 1), x - 1, m - y);}
struct dat {
int f, su, lo, hi;
dat operator + (const dat &z) const {return {add(f, z.f), add(su, z.su), add(lo, z.lo), add(hi, z.hi)};}
dat operator - (const dat &z) const {return {sub(f, z.f), sub(su, z.su), sub(lo, z.lo), sub(hi, z.hi)};}
dat addlo(int z) {return {f, add(su, 1ll * (m - z + 1) * hi % mod), add(lo, f), hi};}
dat addhi(int z) {return {f, add(su, 1ll * z * lo % mod), lo, add(hi, f)};}
} f[N][M][M], E;
struct itv {int l, r, val;};
void solve() {
cin >> n >> m >> t, ans1 = ans2 = 0;
memset(a, 0, sizeof(a)); for(int i = 1; i <= t; i++) cin >> a[i];
for(int i = 1; i < t; i++) {
dn[i] = min(a[i], a[i + 1]), up[i] = max(a[i], a[i + 1]);
for(int j = i + 2; j <= t; j++) a[j] > up[i] ? up[i] = a[j] : dn[i] = (a[j] < dn[i] ? a[j] : -1);
b[i] = dn[i] != -1 ? getg(n - t, dn[i] - 1, m - up[i]) : -1; // -1 表示 r 不合法.
}
memset(f, 0, sizeof(f));
for(int x = 0; x <= m; x++) for(int y = m + 1; y; y--) f[0][x][y] = {1, 0, 0, 0};
for(int i = 1; i < n; i++) { // DP 部分.
for(int v = 1; v <= m; v++) if(chk(i, v)) {
for(int y = m + 1; y > v; y--) f[i][v][y] = (f[i - 1][v - 1][y] - f[i - 1][v - 1][y + 1]).addlo(v);
for(int x = 0; x <= m; x++) f[i][x][v] = f[i][x][v] + (f[i - 1][x][v + 1] - (x ? f[i - 1][x - 1][v + 1] : E)).addhi(v);
}
for(int x = 0; x <= m; x++) for(int y = m + 1; y > x; y--) // 二维前缀和.
f[i][x][y] = f[i][x][y] + f[i][x][y + 1] + (x ? f[i][x - 1][y] - f[i][x - 1][y + 1] : E);
}
for(int r = 1; r <= n; r++) for(int v = 1; v <= m; v++) if(chk(r, v)) {
int k = max(0, r - 1 - t), cd = 0, vd = 0, cu = 0, vu = m + 1, equal = 0;
int pre = 0, suf = 0, tot = 0, sufok = r >= t || r < t && b[r] != -1;
for(int i = 1; i < r && i <= t && vd != -1 && vu != -1; i++) {
if(a[i] < v) cd++, a[i] > vd ? vd = a[i] : vd = -1;
if(a[i] == v) equal++;
if(a[i] > v) cu++, a[i] < vu ? vu = a[i] : vu = -1;
}
if(vu == -1 || vd == -1 || equal > 1 || equal == 1 && a[r - 1] != v) continue; // 前缀不合法, 注意情况 1 的 equal 可以等于 1.
auto calc = [&](int v2) { // (r, n], 被调用 O(nm ^ 2) 次.
if(r == 1 || r == n) return; // r = 1 或 r = n 一定不产生贡献.
if(r < t) { // 总共调用 O(n) 次, 随便暴力.
memset(d, 0, sizeof(d));
if(v < v2 && vu > v2 || v > v2 && vd < v2) return; // 判 r + 1 \in S 不合法.
int hi = getgi(n - t, m - up[r], dn[r] - 1), lo = getgi(n - t, dn[r] - 1, m - up[r]);
for(int i = up[r] + 1; i <= m; i++) d[i] = 1ll * hi * (m - i + 1) % mod;
for(int i = dn[r] - 1; i >= 1; i--) d[i] = 1ll * lo * i % mod;
for(int i = r + 1; i <= t; i++) addt(d[a[i]], 1ll * suf * (a[i] < v ? a[i] : m - a[i] + 1) % mod);
for(int i = max(v, v2) + 1; i <= m; i++) addt(d[i], d[i - 1]);
for(int i = min(v, v2) - 1; i; i--) addt(d[i], d[i + 1]);
for(int i = 1; i < r - (v == v2); i++) addt(ans2, a[i] < v ? d[a[i] + 1] : d[a[i] - 1]); // 已经确定的数的 c = 1.
}
else {
static itv c[3], d[3]; int C = 0, D = 0;
if(v != v2) {
d[D++] = {1, min(v, v2) - 1, getgi(n - (r + 1), min(v, v2) - 1, m - max(v, v2))};
d[D++] = {max(v, v2) + 1, m, getgi(n - (r + 1), m - max(v, v2), min(v, v2) - 1)};
d[D++] = {v2, v2, suf};
if(v < v2) {
if(vd + 1 < v) c[C++] = {vd + 1, v - 1, sub(getgi(k, v - vd - 1, vu - v - 1), getgi(k, v - vd - 1, vu - v2 - 1))};
if(v2 + 1 < vu) c[C++] = {v2 + 1, vu - 1, sub(getgi(k, vu - v - 1, v - vd - 1), getgi(k, vu - v2 - 1, v - vd - 1))};
}
else {
if(v + 1 < vu) c[C++] = {v + 1, vu - 1, sub(getgi(k, vu - v - 1, v - vd - 1), getgi(k, vu - v - 1, v2 - vd - 1))};
if(vd + 1 < v2) c[C++] = {vd + 1, v2 - 1, sub(getgi(k, v - vd - 1, vu - v - 1), getgi(k, v2 - vd - 1, vu - v - 1))};
}
}
else {
d[D++] = {1, v - 1, getgi(n - r, v - 1, m - v)};
d[D++] = {v + 1, m, getgi(n - r, m - v, v - 1)};
if(vd + 1 < v && k) c[C++] = {vd + 1, v - 1, getgi(k - 1, v - vd - 1, vu - v - 1)}; // 这里是 k - 1!
if(v + 1 < vu && k) c[C++] = {v + 1, vu - 1, getgi(k - 1, vu - v - 1, v - vd - 1)};
}
int tot = 0;
for(int i = 1; i < r - (v == v2) && i <= t; i++) for(int j = 0; j < D; j++) { // 计算已经确定的 c_i 的贡献.
const itv &I = d[j];
if(a[i] < v && I.l < v && I.r > a[i]) addt(tot, 1ll * S(max(I.l, a[i] + 1), I.r) * I.val % mod);
if(a[i] > v && I.l > v && I.l < a[i]) addt(tot, 1ll * S(m - min(I.r, a[i] - 1) + 1, m - I.l + 1) * I.val % mod);
}
addt(ans2, 1ll * tot * pre % mod);
for(int i = 0; i < C; i++) for(int j = 0; j < D; j++) {
const itv &I = c[i], &J = d[j]; int coef = 0;
if(I.l < v && J.l < v) { // c 和 d 的区间, 要么后者包含前者, 要么后者为 [v2, v2]
if(J.l <= I.l && I.r <= J.r) coef += (I.r - I.l + 1) * J.r * (J.r + 1) - S2(I.l, I.r) - S(I.l, I.r) >> 1;
else if(I.r < J.l) coef += (I.r - I.l + 1) * J.l; // J.l = J.r = v2
}
if(I.l > v && J.l > v) {
if(J.l <= I.l && I.r <= J.r) coef += S(I.l - J.l, I.r - J.l) * (m + 1) - (S2(I.l, I.r) - S(I.l, I.r) - (I.r - I.l + 1) * (J.l - 1) * J.l >> 1);
else if(J.r < I.l) coef += (I.r - I.l + 1) * (m - J.l + 1);
}
addt(ans2, 1ll * coef * I.val % mod * J.val % mod);
}
}
};
if(r > 1 && chk(r - 1, v) && sufok) {
suf = csuf(r + 1, v, v);
addt(ans1, 1ll * (pre = f[r - 2][v - 1][v + 1].f) * suf % mod);
addt(ans2, 1ll * f[r - 2][v - 1][v + 1].addlo(v).su * suf % mod); // [1, r)
addt(ans2, 1ll * f[r - 2][v - 1][v + 1].lo * v % mod * suf % mod); // r
calc(v);
}
if(equal) continue; // 情况 2 和 3 要求前缀没有等于 a[r] 的元素.
for(int i = 0, val; i <= k; i++) { // 预处理经常要用的式子, 减小常数.
val = 1ll * C[k][i] * max((m - v + 1) * (cu + k - i), v * (cd + i)) % mod;
val1[i] = 1ll * val * C[v - vd - 1][i] % mod, val2[i] = 1ll * val * C[vu - v - 1][k - i] % mod;
}
if(r == n) {
addt(ans1, f[r - 1][v - 1][v + 1].f);
addt(ans2, f[r - 1][v - 1][v + 1].su); // [1, r)
for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r
addt(ans2, 1ll * val1[i] * C[vu - v - 1][k - i] % mod);
}
else if(sufok) {
for(int v2 = v + 1; v2 <= m; v2++) if(chk(r + 1, v2)) {
suf = csuf(r + 2, v, v2), tot = 0;
addt(ans1, 1ll * (pre = (f[r - 1][v - 1][v + 1] - f[r - 1][v - 1][v2 + 1]).f) * suf % mod);
addt(tot, (f[r - 1][v - 1][v + 1] - f[r - 1][v - 1][v2 + 1]).su); // [1, r)
for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r
addt(tot, 1ll * val1[i] * sub(C[vu - v - 1][k - i], k - i < vu - v2 ? C[vu - v2 - 1][k - i] : 0) % mod);
calc(v2), addt(ans2, 1ll * tot * suf % mod);
}
for(int v2 = 1; v2 < v; v2++) if(chk(r + 1, v2)) {
suf = csuf(r + 2, v2, v), tot = 0;
addt(ans1, 1ll * (pre = (f[r - 1][v - 1][v + 1] - f[r - 1][v2 - 1][v + 1]).f) * suf % mod);
addt(tot, (f[r - 1][v - 1][v + 1] - f[r - 1][v2 - 1][v + 1]).su); // [1, r)
for(int i = 0; i <= k; i++) if(i < v - vd && k - i < vu - v) // r
addt(tot, 1ll * val2[i] * sub(C[v - vd - 1][i], i < v2 - vd ? C[v2 - vd - 1][i] : 0) % mod);
calc(v2), addt(ans2, 1ll * tot * suf % mod);
}
}
}
cout << ans1 << " " << ans2 << "\n";
}
bool Med;
int main() {
fprintf(stderr, "%.4lf\n", (&Mbe - &Med) / 1048576.0);
#ifdef ALEX_WEI
FILE* IN = freopen("color.in", "r", stdin);
FILE* OUT = freopen("color.out", "w", stdout);
#endif
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
inv[1] = 1; for(int i = 2; i < M; i++) inv[i] = mod - 1ll * mod / i * inv[mod % i] % mod;
for(int i = 0; i < M; i++) for(int j = 0; j <= i; j++) C[i][j] = j ? add(C[i - 1][j - 1], C[i - 1][j]) : 1;
for(int k = 0; k < N; k++) for(int x = 0; x < M; x++)
for(int y = x; y + x < M; y++) for(int i = max(0, k - y), val; i <= min(x, k); i++)
addt(g[k][x][y], val = 1ll * C[k][i] * C[x][i] % mod * C[y][k - i] % mod), addt(gi[k][x][y], 1ll * i * val % mod);
int T; cin >> T; while(T--) solve();
cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n";
return 0;
}
/*
g++ color.cpp -o color -std=c++14 -O2 -DALEX_WEI
v = v2 的 k -> k - 1.
加入判断 equal = 1 && a[r - 1] != v.
将 S(I.l - J.l, I.r - J.l) * m 改成 S(I.l - J.l, I.r - J.l) * (m + 1).
*/
```