题解 CF1668D【Optimal Partition】
Anguei
2022-04-20 16:04:25
## 题意
给定一个数组 $a$,长度为 $n (1 \leq n \leq 5 \times 10^5)$,你需要将其分割为若干个连续的子数组,使所有子数组的*价值*总和最大。
定义 $\texttt{s(l, r)} = a_l + a_{l+1} + a_{l+2} + \cdots + a_r$,子数组 $a[l, r]$ 的*价值*是:
- $r-l+1$,$\texttt{s(l, r)} > 0$
- 0,$\texttt{s(l, r)} = 0$
- $-(r-l+1)$,$\texttt{s(l, r)} < 0$
## 思路
有意思的一道题目。
### 暴力
若忽略掉数据范围的限制,不难想到一种 $O(n^2)$ 的 dp 方案:
- 先无脑地把前缀和求出来;
- 设 $dp[i]$ 表示前 $i$ 个元素构成的子数组,划分后可产生的最大*价值*;
- 显然 $dp[0] = 0$,即空数组不产生*价值*;
- 考虑转移:枚举 $i, j (j < i)$,则 $dp[i] = \max\left( dp[j] + \texttt{calc(j + 1, i)} \right)$;
- 其中 $\texttt{calc(j + 1, i)}$ 表示子数组 $a[j + 1, r]$ 的*价值*,可以由定义计算;
- $dp[n]$ 就是答案。
```cpp
dp[0] = 0;
for (int i = 1; i <= n; ++i)
for (int j = 0; j < i; ++j)
dp[i] = std::max(dp[i], dp[j] + calc(j + 1, i));
print(dp[n]);
```
### 优化
上面的转移方程 $dp[i] = \max\left( dp[j] + \texttt{calc(j + 1, i)} \right)$ 包含了函数调用,有点麻烦,不妨把它拆开,于是产生了三个新的方程:
- $dp[i] = \max\left( dp[j] + (i-j) \right)$,$\texttt{sum(j + 1, i)} > 0$
- $dp[i] = \max\left( dp[j] \right)$,$\texttt{sum(j + 1, i)} = 0$
- $dp[i] = \max\left( dp[j] + (j-i) \right)$,$\texttt{sum(j + 1, i)} < 0$
而刚刚频繁出现的 $\texttt{sum(j + 1, i)}$ 又可以拆成 $s[i] - s[(j+1)-1] = s[i] - s[j]$。于是上面三个式子又可以移项、变形为:
- $dp[i] = \max\left( dp[j] - j \right) + i$,$s[i] > s[j]$
- $dp[i] = \max\left( dp[j] \right)$,$s[i] = s[j]$
- $dp[i] = \max\left( dp[j] + j \right) - i$,$s[i] < s[j]$
于是只需要动态维护 $dp[i] + i$,$dp[i]$,$dp[i] - i$ 的区间最大值即可。
首先对前缀和数组 $s$ 进行离散化,然后把 $s[j]$ 当成下标。每次查询坐标 $s[j]$ 之前(或之后)的区间最大值。当然,对于 $s[i] = s[j]$ 的情况,相当于查询 $s[j]$ 坐标的最大值。
建立三棵线段树,分别维护即可:第零棵维护 $dp[i] + i$,第一棵维护 $dp[i]$,第二棵维护 $dp[i] - i$。dp 数组更新完之后,再进入线段树修改区间最大值(单点修改)。
~~(又是三棵树,仿佛嗅到了 [CF1660F2](https://www.luogu.com.cn/problem/CF1660F2) 的味道)~~
## 代码
```cpp
struct SegTree {
std::vector<ll> a;
SegTree(int n) : a((n + 1) * 4, 0) { this->build(1, 1, n); }
#define lson ((o) << 1)
#define rson ((o) << 1 | 1)
void build(int o, int l, int r) {
if (l == r) return void(a[o] = -1e18);
int mid = (l + r) / 2;
build(lson, l, mid);
build(rson, mid + 1, r);
a[o] = std::max(a[lson], a[rson]);
}
ll query(int o, int l, int r, int L, int R) {
if (l >= L && r <= R) return a[o];
int mid = (l + r) / 2;
ll res = -1e18;
if (L <= mid) res = std::max(res, query(lson, l, mid, L, R));
if (R > mid) res = std::max(res, query(rson, mid + 1, r, L, R));
return res;
}
void change(int o, int l, int r, int x, ll val) {
if (l == r) return void(a[o] = std::max(a[o], val));
int mid = (l + r) / 2;
if (x <= mid)
change(lson, l, mid, x, val);
else
change(rson, mid + 1, r, x, val);
a[o] = std::max(a[lson], a[rson]);
}
#undef lson
#undef rson
};
void solution() {
int n;
read(n);
std::vector<int> a(n + 1);
for (int i = 1; i <= n; ++i) read(a[i]);
std::vector<ll> s(n + 1);
for (int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
// 离散化前缀和数组 s
std::vector<ll> vs(s.begin(), s.end());
std::sort(vs.begin(), vs.end());
std::map<ll, int> belong;
int tot = 0;
for (auto i : vs)
if (!belong.count(i)) belong[i] = ++tot;
// s2 是离散化后的 s
std::vector<int> s2(n + 1);
for (int i = 1; i <= n; ++i) s2[i] = belong[s[i]];
auto chmax = [](auto& x, auto y) { x = std::max(x, y); };
std::vector<SegTree> seg(3, SegTree(tot));
// 下面这行相当于暴力代码的 dp[0] = 0
for (int i = 0; i < 3; ++i) seg[i].change(1, 1, tot, belong[0], 0);
for (int i = 1; i <= n; i++) {
// 对应上述第一个转移方程
if (s2[i] > 1) chmax(dp[i], seg[2].query(1, 1, tot, 1, s2[i] - 1) + i);
// 第二个转移方程
chmax(dp[i], seg[1].query(1, 1, tot, s2[i], s2[i]));
// 第三个转移方程
if (s2[i] < tot) chmax(dp[i], seg[0].query(1, 1, tot, s2[i] + 1, tot) - i);
// 单点修改
seg[0].change(1, 1, tot, s2[i], dp[i] + i);
seg[1].change(1, 1, tot, s2[i], dp[i]);
seg[2].change(1, 1, tot, s2[i], dp[i] - i);
}
print(dp[n]);
}
```