在线决策单调性地皮还能单老哥分治做?
Register_int
·
·
算法·理论
原博客: 簡易版 LARSCH Algorithm noshi91。
众所周知决策单调性 dp 有很多做法,大家应该都会。但是主播主播,二分队列遇上要莫队算的贡献还是太菜了,整体二分还是太吃离线了,SMAWK 算法和 Wilber 算法又太吃常数和码量了。那么有没有什么常数小、能处理莫队、还好写的决策单调性优化 dp 呢?
有的兄弟有的,这就是我们要介绍的:
简易版 LARSCH 算法
设 dp 数组为 f。考虑分治。尝试设计一个 \operatorname{solve}(l,r) 函数。我们希望运行完这个函数后,f_{l+1\sim r} 的值都已经算对了。但是考虑分治中心 mid,我们如果要求 f_{mid} 一定要算对,那么 0\sim mid-1 的东西都要算对,这显然是不现实的。所以我们需要分成若干段考虑贡献。
具体地,我们在算 mid 时,不要求所有可能成为决策点的地方都算对了,而是要求 f_l 的决策点到 f_r 的大致决策点内的东西都算对了。我们需要:
-
- 只考虑 0\sim l 的话,r 的 f 值和决策点都算对了。
设 mid=(l+r)/2,那么我们需要做的事情是:
- 把 l 的决策点到 r 的当前决策点之间的点转移到 mid,更新 f_{mid} 的值与决策点。
- 递归 \operatorname{solve}(l,mid)。
- 把 l+1 到 mid 之间的点转移到 r,更新 f_r 的值与决策点。
- 递归 \operatorname{solve}(mid,r)。
然后我们就把所有 f 值算对了。
为什么?首先看第一步。由决策单调性,f_{mid} 在只考虑 0\sim l 时的决策点,必然在 f_l 的决策点与 f_r 的决策点之间,所以第二步的递归前提是满足的。第三步中,f_r 多考虑了 l+1\sim mid 的部分,因此第四步的递归前提仍是满足的。而对于最后一层的递归 \operatorname{solve}(k-1,k),f_k 已经考虑了 0\sim k-1 的所有位置,因此运行结束后所有位置的 f 都算对了。
时间复杂度是多少呢?可以发现,对于分治的每一层,我们都相当于把所有点遍历了一遍,所以复杂度是 T(n)=2T(n/2)+\mathcal{O}(n)=\mathcal{O}(n\log n) 的。
当然该算法也有局限性。具体的,由于本算法对决策点进行了部分估计,在限制不够强时可能无法包含最优决策。因此 dp 的转移应当满足四边形不等式。
例题
wqs 二分使用例:[ABC355G] Baseball
设 dp_{i,j} 表示前 i 个位置放了 j 个点的方案数,则有
dp_{i,j}=\min_kdp_{k,j-1}+w(k,i)
$$f_i=\min_j f_j+w(j,i)$$
用上述算法解决,时间复杂度 $\mathcal{O}(n\log n\log V)$。
:::info[Code]
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 5e4 + 10;
int n, m; ll s[MAXN], si[MAXN];
inline
ll w(int l, int r) {
if (l > r) return 0;
if (l == 1 && r == n) return 1e18;
if (l == 1) return (r + 1) * s[r] - si[r];
if (r == n) return (si[n] - si[l - 1]) - (l - 1) * (s[n] - s[l - 1]);
int mid = l + r >> 1;
return (si[mid] - si[l - 1]) - (l - 1) * (s[mid] - s[l - 1])
+ (r + 1) * (s[r] - s[mid]) - (si[r] - si[mid]);
}
ll dp[MAXN], X; int cnt[MAXN], p[MAXN];
inline
void check(int i, int j) {
ll x = dp[j] + w(j + 1, i - 1) + X;
if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}
void solve(int l, int r) {
if (r - l == 1) return ; int mid = l + r >> 1;
for (int i = p[l]; i <= p[r]; i++) check(mid, i);
solve(l, mid);
for (int i = l + 1; i <= mid; i++) check(r, i);
solve(mid, r);
}
inline
bool check() {
for (int i = 1; i <= n + 1; i++) dp[i] = 1e18, cnt[i] = p[i] = 0;
check(n + 1, 0), solve(0, n + 1);
return cnt[n + 1] <= m;
}
ll l, r, ans;
int main() {
scanf("%d%d", &n, &m), m++;
for (int i = 1; i <= n; i++) scanf("%lld", &s[i]);
for (int i = 1; i <= n; i++) si[i] = si[i - 1] + i * s[i];
for (int i = 1; i <= n; i++) s[i] += s[i - 1];
for (l = 0, r = 1e10; l <= r; ) {
X = l + r >> 1;
if (check()) r = X - 1, ans = X;
else l = X + 1;
}
X = ans, check(), printf("%lld", dp[n + 1] - X * m);
}
```
:::
### 莫队使用例:[CF868F Yet Another Minimization Problem](https://www.luogu.com.cn/problem/CF868F)
非常经典的题了。由于转移是离线的,可以用常规的整体二分方法解决。这里尝试使用本文介绍的算法把它做掉。
容易发现,第一步和第三步的莫队都满足移动次数 $\mathcal{O}(n\log n)$ 的性质,分开跑即可。
**注意不要用同一个莫队跑**。否则指针会在决策点与区间内反复横跳,导致复杂度退化为平方级别。
:::info[Code]
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int a[MAXN];
struct node {
int cnt[MAXN], l, r; ll ans;
node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); }
inline
ll w(int ql, int qr) {
if (ql > qr) return 0;
for (; l > ql; ans += cnt[a[--l]]++);
for (; r < qr; ans += cnt[a[++r]]++);
for (; l < ql; ans -= --cnt[a[l++]]);
for (; r > qr; ans -= --cnt[a[r--]]);
return ans;
}
} A, B;
ll dp[MAXN][30]; int p[MAXN][30];
inline
void checkA(int i, int j, int k) {
ll x = dp[j][k - 1] + A.w(j + 1, i);
if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j;
}
inline
void checkB(int i, int j, int k) {
ll x = dp[j][k - 1] + B.w(j + 1, i);
if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j;
}
void solve(int l, int r, int k) {
if (r - l == 1) return ; int mid = l + r >> 1;
for (int i = p[l][k]; i <= p[r][k]; i++) checkA(mid, i, k);
solve(l, mid, k);
for (int i = l + 1; i <= mid; i++) checkB(r, i, k);
solve(mid, r, k);
}
int n, k;
int main() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
memset(dp, 0x3f, sizeof dp), **dp = 0;
for (int i = 1; i <= k; i++) checkA(n, 0, i), solve(0, n, i);
printf("%lld", dp[n][k]);
}
```
:::
### 莫队使用例:[P9266 [PA 2022] Nawiasowe podziały](https://www.luogu.com.cn/problem/P9266)
设一个区间的代价为 $w(l,r)$,即区间 $[l,r]$ 内的合法括号子串个数。我们有 $w(l,r)+w(l+1,r-1)\ge w(l+1,r)+w(l,r-1)$。因此转移满足四边形不等式,同时可以用 wqs 二分去掉段数的限制。
但这时你发现个问题,这个代价很难用莫队以外的办法算出来。这导致二分队列直接倒闭了。同时,转移是在线的,整体二分还要套 cdq 才能处理。
本讲方法的优越性就体现出来了,可以直接做到 $\mathcal{O}(n\log n\log V)$,常数较小。
:::info[Code]
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int n, m, pos[MAXN], col[MAXN]; char s[MAXN];
struct node {
int cnt[MAXN], l, r; ll ans;
node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); }
inline
ll w(int ql, int qr) {
if (ql > qr) return 0;
for (; l > ql; ) {
l--;
if (s[l] == '(' && pos[l] <= r) ans += ++cnt[col[l]];
}
for (; r < qr; ) {
r++;
if (s[r] == ')' && pos[r] >= l) ans += ++cnt[col[r]];
}
for (; l < ql; ) {
if (s[l] == '(' && pos[l] <= r) ans -= cnt[col[l]]--;
l++;
}
for (; r > qr; ) {
if (s[r] == ')' && pos[r] >= l) ans -= cnt[col[r]]--;
r--;
}
return ans;
}
} A, B;
ll dp[MAXN], X; int cnt[MAXN], p[MAXN];
inline
void checkA(int i, int j) {
ll x = dp[j] + A.w(j + 1, i) + X;
if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}
inline
void checkB(int i, int j) {
ll x = dp[j] + B.w(j + 1, i) + X;
if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}
void solve(int l, int r) {
if (r - l == 1) return ; int mid = l + r >> 1;
for (int i = p[l]; i <= p[r]; i++) checkA(mid, i);
solve(l, mid);
for (int i = l + 1; i <= mid; i++) checkB(r, i);
solve(mid, r);
}
inline
bool check() {
for (int i = 1; i <= n; i++) dp[i] = 1e18, cnt[i] = p[i] = 0;
checkA(n, 0), solve(0, n);
return cnt[n] <= m;
}
int st[MAXN], tp, id;
ll l, r, ans;
int main() {
scanf("%d%d%s", &n, &m, s + 1);
for (int i = 1; i <= n; i++) {
if (s[i] == '(') { st[++tp] = i, pos[i] = n + 1; continue; }
if (!tp) continue; int j = st[tp--]; pos[i] = j, pos[j] = i;
col[i] = col[j] = (col[j - 1] ? col[j - 1] : ++id);
}
for (l = 0, r = 1e10; l <= r; ) {
X = l + r >> 1;
if (check()) r = X - 1, ans = X;
else l = X + 1;
}
X = ans, check();
printf("%lld", dp[n] - X * m);
}
```
:::