Luogu P13272 [NOI 2025] 序列变换
rizynvu
·
·
题解
cnblogs。
首先因为这题要计数,所以尝试去找一个刻画方式刻画出所有能被 a 生成的序列。
考虑操作实际就是让 a_i, a_{i + 1} 同时减去 \min\{a_i, a_{i + 1}\}。
在这之后一定会有一个数变为 0,且发现若操作的两个数中有 0,那么情况一定不变。
于是可以知道,如果多次选择了 (i, i + 1),那么除第一次的操作肯定都是无效的。
那么直接指定每对 (i, i + 1) 至多被操作一次即可。
考虑操作时若 a_i \le a_{i + 1},那么认为有 i\to i + 1;若 a_i \ge a_{i + 1},那么有 i\gets i + 1(相等会有点问题,但是这里先暂时不考虑)。
这个边集一定能被分成若干个 \cdots\to\to\to\gets\gets\gets\cdots 的连续段,即左端从 l,右端从 r 往内汇聚,直到一个分界线 k 挡住了两测。
于是可以首先预处理出单侧的情况,记 lv_{l, k} 表示 l 汇聚到 k 时的值,易知有 lv_{l, l} = a_l, lv_{l, k + 1} = a_{k + 1} - lv_{l, k}。
需要保证中途都是可以操作的,即若 lv_{l, k} < 0,那么其实都走不到 k 这里,从 k 开始的 lv_{l, k} 都是不合法的。
同理定义 rv_{r, k}。
那么考虑暴力的做法,直接枚举 (l, r, k),首先肯定要满足 l, r 都能走到 k,然后再来分讨一下 k 处的情况:
(这里与 a_k 比较只是因为前面定义的 lv_{l, k}, rv_{r, k} 相加时会多一个 a_k,实际上还是与 0 比较。)
于是现在就有了个 \mathcal{O}(n^3) 的做法,不过会发现这个做法其实有点问题,计数的时候可能会计重。
例如 a = [1, 1, 1, 1],会认为 [1, 2], [3, 4], [1, 4] 都可以消除,那这就爆了。
思考一下原因,对于 (l, r, k),其实如果删除一段前缀或后继(不包括 k),假设通过这段后得到的值是 x,那么对最终 a_k 的影响肯定是有 +x 或 -x 的,那么肯定就需要另一边同样调整以使最终 a_k 不变,但是这样又会影响周边的段继续修改,所以一定会继续往外扩张的,一定不合法——吗?
于是发现了问题:当 x = 0 时,且如果换一个方向传入的值也是 0 时,那么这一段是可以分到任意一边的。
为了解决这个问题,只需要考虑小修一下 lv_{l, k}, rv_{r, k} 的定义:当 \le 0 时就是不合法的。这样的 \mathcal{O}(n^3) 做法就是正确的。
接下来考虑继续优化,发现 dp 中直接对于 $r$ 枚举 $l$ 并统计贡献看着已经比较优了,于是需要解决的应当是快速处理对于所有 $k$,$(l, r, k)$ 的贡献。
上文已经发现了一个事实:对于一个数 $x$,其对中心的贡献一定是 $+x$ 或 $-x$。
从这里入手,能发现符号一定是与中心距离为偶数时为 $+$,为奇数时为 $-$。即包括中心的右侧符号应当形如 $+-+-+-+-\cdots$,左侧也是对称的。
于是对于 $l, r$ 来说,若 $k$ 的奇偶性相同,则 $a_k$ 最后得到的值一定也是相同的,且不同奇偶性的 $k$ 得到的值一定互为相反数(不过不知道这个也没有啥)。
记 $l$ 能扩展到的最远的 $k$ 为 $lb_l$,同理记 $rb_r$。
那么合法的 $k$ 一定是在 $[\max(l, rb_r), \min(r, lb_l)]$ 中的所有奇数或偶数或全部(中间值 $= 0$ 时),只需要知道这部分的 $\max \{-b_k\}$ 和 $\sum \frac{1}{c_k}$,可以写个 st 表前缀和,不过因为数据范围并不大,直接 $\mathcal{O}(n^2)$ 预处理也可以。
一个小细节:上述判断中间值 $= 0$ 的方法还是有点问题,不过能发现问题只出在 $a_i = a_{i + 1}$,特殊处理一下即可。
代码是复刻的,应该没有问题。
$\mathcal{O}(n^3)$(判断 $=0$ 的方式有点不一样):
```cpp
inline void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
for (int i = 1; i <= n; i++) {
memset(lv[i] + 1, -1, sizeof(int) * n);
lv[i][i] = a[i];
for (int j = i; j < n && lv[i][j] < a[j + 1]; j++) {
lv[i][j + 1] = a[j + 1] - lv[i][j];
}
}
for (int i = n; i >= 1; i--) {
memset(rv[i] + 1, -1, sizeof(int) * n);
rv[i][i] = a[i];
for (int j = i; j > 1 && rv[i][j] < a[j - 1]; j--) {
rv[i][j - 1] = a[j - 1] - rv[i][j];
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
ok0[l][r] = false;
}
}
for (int i = 1; i < n; i++) {
for (int l = 1; l <= i; l++) {
for (int r = i + 1; r <= n; r++) {
ok0[l][r] |= lv[l][i] != -1 && rv[r][i + 1] != -1 && lv[l][i] == rv[r][i + 1];
}
}
}
for (int i = 1; i <= n; i++) {
for (int l = 1; l <= i; l++) {
for (int r = i; r <= n; r++) {
ok[i][l][r] = lv[l][i] != -1 && rv[r][i] != -1 && lv[l][i] + rv[r][i] > a[i];
}
}
}
pre[0] = 0;
for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + b[i];
f[0] = 0;
for (int i = 1; i <= n; i++) {
f[i] = f[i - 1];
for (int j = 1; j <= i; j++) {
if (ok0[j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1]);
for (int k = j; k <= i; k++) {
if (ok[k][j][i]) f[i] = std::max(f[i], f[j - 1] + pre[i] - pre[j - 1] - b[k]);
}
}
}
pr[0] = ipr[0] = 1;
for (int i = 1; i <= n; i++) {
ic[i] = qpow(c[i], mod - 2);
pr[i] = pr[i - 1] * c[i] % mod;
ipr[i] = ipr[i - 1] * ic[i] % mod;
}
g[0] = 1;
for (int i = 1; i <= n; i++) {
g[i] = 0;
for (int j = 1; j <= i; j++) {
if (ok0[j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1]) % mod;
for (int k = j; k <= i; k++) {
if (ok[k][j][i]) g[i] = (g[i] + g[j - 1] * pr[i] % mod * ipr[j - 1] % mod * ic[k]) % mod;
}
}
}
printf("%lld %lld\n", f[n], g[n]);
}
```
$\mathcal{O}(n^2)$:
```cpp
#include <bits/stdc++.h>
using ll = long long;
constexpr ll mod = 1e9 + 7;
inline ll qpow(ll a, ll b) {
ll v = 1;
for (; b; b >>= 1, a = a * a % mod) {
if (b & 1) v = v * a % mod;
}
return v;
}
constexpr int maxn = 5000 + 10;
int n;
int a[maxn], b[maxn], c[maxn];
int lb[maxn], rb[maxn];
bool ok0[maxn][maxn];
ll preb[maxn], fv[maxn][maxn], f[maxn];
ll prec[maxn], iprec[maxn], ic[maxn], gv[maxn][maxn], g[maxn];
ll prea[maxn];
int pic[maxn][maxn][2], mxb[maxn][maxn][2];
inline void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
prea[0] = 0;
for (int i = 1; i <= n; i++) {
prea[i] = prea[i - 1] + (i % 2 ? a[i] : -a[i]);
}
preb[0] = 0, prec[0] = iprec[0] = 1;
for (int i = 1; i <= n; i++) {
ic[i] = qpow(c[i], mod - 2);
preb[i] = preb[i - 1] + b[i];
prec[i] = prec[i - 1] * c[i] % mod;
iprec[i] = iprec[i - 1] * ic[i] % mod;
}
for (int l = 1; l <= n; l++) {
pic[l][l - 1][0] = pic[l][l - 1][1] = 0;
mxb[l][l - 1][0] = mxb[l][l - 1][1] = -2e9;
for (int r = l; r <= n; r++) {
pic[l][r][0] = pic[l][r - 1][0];
pic[l][r][1] = pic[l][r - 1][1];
mxb[l][r][0] = mxb[l][r - 1][0];
mxb[l][r][1] = mxb[l][r - 1][1];
pic[l][r][r % 2] = (pic[l][r][r % 2] + ic[r]) % mod;
mxb[l][r][r % 2] = std::max(mxb[l][r][r % 2], -b[r]);
}
}
for (int i = 1; i <= n; i++) {
lb[i] = i;
for (int x = a[i]; lb[i] < n && x < a[lb[i] + 1]; ) {
x = a[++lb[i]] - x;
}
}
for (int i = n; i >= 1; i--) {
rb[i] = i;
for (int x = a[i]; rb[i] > 1 && x < a[rb[i] - 1]; ) {
x = a[--rb[i]] - x;
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
fv[l][r] = -1e18, gv[l][r] = 0;
}
}
auto conv = [&](const int x) {
return x == -2e9 ? (ll)-1e18 : (ll)x;
};
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
if (prea[r] == prea[l - 1] || lb[l] < rb[r]) continue;
const int op = prea[r] > prea[l - 1];
const int st = std::max(rb[r], l);
const int ed = std::min(lb[l], r);
fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1] + conv(mxb[st][ed][op]));
gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1] % mod * pic[st][ed][op]) % mod;
}
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
ok0[l][r] = prea[r] - prea[l - 1] == 0 && rb[r] <= lb[l];
}
}
for (int i = 1; i < n; i++) {
ok0[i][i + 1] |= a[i] == a[i + 1];
}
for (int l = 1; l <= n; l++) {
for (int r = l; r <= n; r++) {
if (ok0[l][r]) {
fv[l][r] = std::max(fv[l][r], preb[r] - preb[l - 1]);
gv[l][r] = (gv[l][r] + prec[r] * iprec[l - 1]) % mod;
}
}
}
f[0] = 0;
for (int i = 1; i <= n; i++) {
f[i] = -1e18;
for (int j = 1; j <= i; j++) {
f[i] = std::max(f[i], f[j - 1] + fv[j][i]);
}
}
g[0] = 1;
for (int i = 1; i <= n; i++) {
g[i] = 0;
for (int j = 1; j <= i; j++) {
g[i] = (g[i] + g[j - 1] * gv[j][i]) % mod;
}
}
printf("%lld %lld\n", f[n], g[n]);
}
int main() {
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
int testid, t;
scanf("%d%d", &testid, &t);
while (t--) solve();
return 0;
}
```