Luogu P13272 [NOI 2025] 序列变换

· · 题解

cnblogs。

首先因为这题要计数,所以尝试去找一个刻画方式刻画出所有能被 a 生成的序列。

考虑操作实际就是让 a_i, a_{i + 1} 同时减去 \min\{a_i, a_{i + 1}\}

在这之后一定会有一个数变为 0,且发现若操作的两个数中有 0,那么情况一定不变。

于是可以知道,如果多次选择了 (i, i + 1),那么除第一次的操作肯定都是无效的。
那么直接指定每对 (i, i + 1) 至多被操作一次即可。

考虑操作时若 a_i \le a_{i + 1},那么认为有 i\to i + 1;若 a_i \ge a_{i + 1},那么有 i\gets i + 1(相等会有点问题,但是这里先暂时不考虑)。
这个边集一定能被分成若干个 \cdots\to\to\to\gets\gets\gets\cdots 的连续段,即左端从 l,右端从 r 往内汇聚,直到一个分界线 k 挡住了两测。

于是可以首先预处理出单侧的情况,记 lv_{l, k} 表示 l 汇聚到 k 时的值,易知有 lv_{l, l} = a_l, lv_{l, k + 1} = a_{k + 1} - lv_{l, k}
需要保证中途都是可以操作的,即若 lv_{l, k} < 0,那么其实都走不到 k 这里,从 k 开始的 lv_{l, k} 都是不合法的。
同理定义 rv_{r, k}

那么考虑暴力的做法,直接枚举 (l, r, k),首先肯定要满足 l, r 都能走到 k,然后再来分讨一下 k 处的情况:

(这里与 a_k 比较只是因为前面定义的 lv_{l, k}, rv_{r, k} 相加时会多一个 a_k,实际上还是与 0 比较。)

于是现在就有了个 \mathcal{O}(n^3) 的做法,不过会发现这个做法其实有点问题,计数的时候可能会计重。

例如 a = [1, 1, 1, 1],会认为 [1, 2], [3, 4], [1, 4] 都可以消除,那这就爆了。
思考一下原因,对于 (l, r, k),其实如果删除一段前缀或后继(不包括 k),假设通过这段后得到的值是 x,那么对最终 a_k 的影响肯定是有 +x-x 的,那么肯定就需要另一边同样调整以使最终 a_k 不变,但是这样又会影响周边的段继续修改,所以一定会继续往外扩张的,一定不合法——吗?
于是发现了问题:当 x = 0 时,且如果换一个方向传入的值也是 0 时,那么这一段是可以分到任意一边的。

为了解决这个问题,只需要考虑小修一下 lv_{l, k}, rv_{r, k} 的定义:当 \le 0 时就是不合法的。这样的 \mathcal{O}(n^3) 做法就是正确的。

接下来考虑继续优化,发现 dp 中直接对于 $r$ 枚举 $l$ 并统计贡献看着已经比较优了,于是需要解决的应当是快速处理对于所有 $k$,$(l, r, k)$ 的贡献。 上文已经发现了一个事实:对于一个数 $x$,其对中心的贡献一定是 $+x$ 或 $-x$。 从这里入手,能发现符号一定是与中心距离为偶数时为 $+$,为奇数时为 $-$。即包括中心的右侧符号应当形如 $+-+-+-+-\cdots$,左侧也是对称的。 于是对于 $l, r$ 来说,若 $k$ 的奇偶性相同,则 $a_k$ 最后得到的值一定也是相同的,且不同奇偶性的 $k$ 得到的值一定互为相反数(不过不知道这个也没有啥)。 记 $l$ 能扩展到的最远的 $k$ 为 $lb_l$,同理记 $rb_r$。 那么合法的 $k$ 一定是在 $[\max(l, rb_r), \min(r, lb_l)]$ 中的所有奇数或偶数或全部(中间值 $= 0$ 时),只需要知道这部分的 $\max \{-b_k\}$ 和 $\sum \frac{1}{c_k}$,可以写个 st 表前缀和,不过因为数据范围并不大,直接 $\mathcal{O}(n^2)$ 预处理也可以。 一个小细节:上述判断中间值 $= 0$ 的方法还是有点问题,不过能发现问题只出在 $a_i = a_{i + 1}$,特殊处理一下即可。 代码是复刻的,应该没有问题。 $\mathcal{O}(n^3)$(判断 $=0$ 的方式有点不一样): ```cpp inline void solve() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 1; i <= n; i++) scanf("%d", &b[i]); for (int i = 1; i <= n; i++) scanf("%d", &c[i]); for (int i = 1; i <= n; i++) { memset(lv[i] + 1, -1, sizeof(int) * n); lv[i][i] = a[i]; for (int j = i; j < n && lv[i][j] < a[j + 1]; j++) { lv[i][j + 1] = a[j + 1] - lv[i][j]; } } for (int i = n; i >= 1; i--) { memset(rv[i] + 1, -1, sizeof(int) * n); rv[i][i] = a[i]; for (int j = i; j > 1 && rv[i][j] < a[j - 1]; j--) { rv[i][j - 1] = a[j - 1] - rv[i][j]; } } for (int l = 1; l <= n; l++) { for (int r = l; r <= n; r++) { ok0[l][r] = false; } } for (int i = 1; i < n; i++) { for (int l = 1; l <= i; l++) { for (int r = i + 1; r <= n; r++) { ok0[l][r] |= lv[l][i] != -1 && rv[r][i + 1] != -1 && lv[l][i] == rv[r][i + 1]; } } } for (int i = 1; i <= n; i++) { for (int l = 1; l <= i; l++) { for (int r = i; r <= n; r++) { ok[i][l][r] = lv[l][i] != -1 && rv[r][i] != -1 && lv[l][i] + rv[r][i] > a[i]; } } } pre[0] = 0; for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + b[i]; f[0] = 0; for (int i = 1; i <= n; i++) { f[i] = f[i - 1]; for (int j = 1; j <= i; j++) { if (ok0[j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1]); for (int k = j; k <= i; k++) { if (ok[k][j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1] - b[k]); } } } pr[0] = ipr[0] = 1; for (int i = 1; i <= n; i++) { ic[i] = qpow(c[i], mod - 2); pr[i] = pr[i - 1] * c[i] % mod; ipr[i] = ipr[i - 1] * ic[i] % mod; } g[0] = 1; for (int i = 1; i <= n; i++) { g[i] = 0; for (int j = 1; j <= i; j++) { if (ok0[j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1]) % mod; for (int k = j; k <= i; k++) { if (ok[k][j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1] % mod * ic[k]) % mod; } } } printf("%lld %lld\n", f[n], g[n]); } ``` $\mathcal{O}(n^2)$: ```cpp #include <bits/stdc++.h> using ll = long long; constexpr ll mod = 1e9 + 7; inline ll qpow(ll a, ll b) { ll v = 1; for (; b; b >>= 1, a = a * a % mod) { if (b & 1) v = v * a % mod; } return v; } constexpr int maxn = 5000 + 10; int n; int a[maxn], b[maxn], c[maxn]; int lb[maxn], rb[maxn]; bool ok0[maxn][maxn]; ll preb[maxn], fv[maxn][maxn], f[maxn]; ll prec[maxn], iprec[maxn], ic[maxn], gv[maxn][maxn], g[maxn]; ll prea[maxn]; int pic[maxn][maxn][2], mxb[maxn][maxn][2]; inline void solve() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 1; i <= n; i++) scanf("%d", &b[i]); for (int i = 1; i <= n; i++) scanf("%d", &c[i]); prea[0] = 0; for (int i = 1; i <= n; i++) { prea[i] = prea[i - 1] + (i % 2 ? a[i] : -a[i]); } preb[0] = 0, prec[0] = iprec[0] = 1; for (int i = 1; i <= n; i++) { ic[i] = qpow(c[i], mod - 2); preb[i] = preb[i - 1] + b[i]; prec[i] = prec[i - 1] * c[i] % mod; iprec[i] = iprec[i - 1] * ic[i] % mod; } for (int l = 1; l <= n; l++) { pic[l][l - 1][0] = pic[l][l - 1][1] = 0; mxb[l][l - 1][0] = mxb[l][l - 1][1] = -2e9; for (int r = l; r <= n; r++) { pic[l][r][0] = pic[l][r - 1][0]; pic[l][r][1] = pic[l][r - 1][1]; mxb[l][r][0] = mxb[l][r - 1][0]; mxb[l][r][1] = mxb[l][r - 1][1]; pic[l][r][r % 2] = (pic[l][r][r % 2] + ic[r]) % mod; mxb[l][r][r % 2] = std::max(mxb[l][r][r % 2], -b[r]); } } for (int i = 1; i <= n; i++) { lb[i] = i; for (int x = a[i]; lb[i] < n && x < a[lb[i] + 1]; ) { x = a[++lb[i]] - x; } } for (int i = n; i >= 1; i--) { rb[i] = i; for (int x = a[i]; rb[i] > 1 && x < a[rb[i] - 1]; ) { x = a[--rb[i]] - x; } } for (int l = 1; l <= n; l++) { for (int r = l; r <= n; r++) { fv[l][r] = -1e18, gv[l][r] = 0; } } auto conv = [&](const int x) { return x == -2e9 ? (ll)-1e18 : (ll)x; }; for (int l = 1; l <= n; l++) { for (int r = l; r <= n; r++) { if (prea[r] == prea[l - 1] || lb[l] < rb[r]) continue; const int op = prea[r] > prea[l - 1]; const int st = std::max(rb[r], l); const int ed = std::min(lb[l], r); fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1] + conv(mxb[st][ed][op])); gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1] % mod * pic[st][ed][op]) % mod; } } for (int l = 1; l <= n; l++) { for (int r = l; r <= n; r++) { ok0[l][r] = prea[r] - prea[l - 1] == 0 && rb[r] <= lb[l]; } } for (int i = 1; i < n; i++) { ok0[i][i + 1] |= a[i] == a[i + 1]; } for (int l = 1; l <= n; l++) { for (int r = l; r <= n; r++) { if (ok0[l][r]) { fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1]); gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1]) % mod; } } } f[0] = 0; for (int i = 1; i <= n; i++) { f[i] = -1e18; for (int j = 1; j <= i; j++) { f[i] = std::max(f[i], f[j - 1] + fv[j][i]); } } g[0] = 1; for (int i = 1; i <= n; i++) { g[i] = 0; for (int j = 1; j <= i; j++) { g[i] = (g[i] + g[j - 1] * gv[j][i]) % mod; } } printf("%lld %lld\n", f[n], g[n]); } int main() { freopen("sequence.in", "r", stdin); freopen("sequence.out", "w", stdout); int testid, t; scanf("%d%d", &testid, &t); while (t--) solve(); return 0; } ```