在线决策单调性地皮还能单老哥分治做?

· · 算法·理论

原博客: 簡易版 LARSCH Algorithm noshi91。

众所周知决策单调性 dp 有很多做法,大家应该都会。但是主播主播,二分队列遇上要莫队算的贡献还是太菜了,整体二分还是太吃离线了,SMAWK 算法和 Wilber 算法又太吃常数和码量了。那么有没有什么常数小、能处理莫队、还好写的决策单调性优化 dp 呢?

有的兄弟有的,这就是我们要介绍的:

简易版 LARSCH 算法

设 dp 数组为 f。考虑分治。尝试设计一个 \operatorname{solve}(l,r) 函数。我们希望运行完这个函数后,f_{l+1\sim r} 的值都已经算对了。但是考虑分治中心 mid,我们如果要求 f_{mid} 一定要算对,那么 0\sim mid-1 的东西都要算对,这显然是不现实的。所以我们需要分成若干段考虑贡献。

具体地,我们在算 mid 时,不要求所有可能成为决策点的地方都算对了,而是要求 f_l 的决策点到 f_r大致决策点内的东西都算对了。我们需要:

mid=(l+r)/2,那么我们需要做的事情是:

  1. l 的决策点到 r当前决策点之间的点转移到 mid,更新 f_{mid} 的值与决策点。
  2. 递归 \operatorname{solve}(l,mid)
  3. l+1mid 之间的点转移到 r,更新 f_r 的值与决策点。
  4. 递归 \operatorname{solve}(mid,r)

然后我们就把所有 f 值算对了。

为什么?首先看第一步。由决策单调性,f_{mid} 在只考虑 0\sim l 时的决策点,必然在 f_l 的决策点与 f_r 的决策点之间,所以第二步的递归前提是满足的。第三步中,f_r 多考虑了 l+1\sim mid 的部分,因此第四步的递归前提仍是满足的。而对于最后一层的递归 \operatorname{solve}(k-1,k)f_k 已经考虑了 0\sim k-1 的所有位置,因此运行结束后所有位置的 f 都算对了。

时间复杂度是多少呢?可以发现,对于分治的每一层,我们都相当于把所有点遍历了一遍,所以复杂度是 T(n)=2T(n/2)+\mathcal{O}(n)=\mathcal{O}(n\log n) 的。

当然该算法也有局限性。具体的,由于本算法对决策点进行了部分估计,在限制不够强时可能无法包含最优决策。因此 dp 的转移应当满足四边形不等式

例题

wqs 二分使用例:[ABC355G] Baseball

dp_{i,j} 表示前 i 个位置放了 j 个点的方案数,则有

dp_{i,j}=\min_kdp_{k,j-1}+w(k,i) $$f_i=\min_j f_j+w(j,i)$$ 用上述算法解决,时间复杂度 $\mathcal{O}(n\log n\log V)$。 :::info[Code] ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAXN = 5e4 + 10; int n, m; ll s[MAXN], si[MAXN]; inline ll w(int l, int r) { if (l > r) return 0; if (l == 1 && r == n) return 1e18; if (l == 1) return (r + 1) * s[r] - si[r]; if (r == n) return (si[n] - si[l - 1]) - (l - 1) * (s[n] - s[l - 1]); int mid = l + r >> 1; return (si[mid] - si[l - 1]) - (l - 1) * (s[mid] - s[l - 1]) + (r + 1) * (s[r] - s[mid]) - (si[r] - si[mid]); } ll dp[MAXN], X; int cnt[MAXN], p[MAXN]; inline void check(int i, int j) { ll x = dp[j] + w(j + 1, i - 1) + X; if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j; else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j; } void solve(int l, int r) { if (r - l == 1) return ; int mid = l + r >> 1; for (int i = p[l]; i <= p[r]; i++) check(mid, i); solve(l, mid); for (int i = l + 1; i <= mid; i++) check(r, i); solve(mid, r); } inline bool check() { for (int i = 1; i <= n + 1; i++) dp[i] = 1e18, cnt[i] = p[i] = 0; check(n + 1, 0), solve(0, n + 1); return cnt[n + 1] <= m; } ll l, r, ans; int main() { scanf("%d%d", &n, &m), m++; for (int i = 1; i <= n; i++) scanf("%lld", &s[i]); for (int i = 1; i <= n; i++) si[i] = si[i - 1] + i * s[i]; for (int i = 1; i <= n; i++) s[i] += s[i - 1]; for (l = 0, r = 1e10; l <= r; ) { X = l + r >> 1; if (check()) r = X - 1, ans = X; else l = X + 1; } X = ans, check(), printf("%lld", dp[n + 1] - X * m); } ``` ::: ### 莫队使用例:[CF868F Yet Another Minimization Problem](https://www.luogu.com.cn/problem/CF868F) 非常经典的题了。由于转移是离线的,可以用常规的整体二分方法解决。这里尝试使用本文介绍的算法把它做掉。 容易发现,第一步和第三步的莫队都满足移动次数 $\mathcal{O}(n\log n)$ 的性质,分开跑即可。 **注意不要用同一个莫队跑**。否则指针会在决策点与区间内反复横跳,导致复杂度退化为平方级别。 :::info[Code] ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAXN = 1e5 + 10; int a[MAXN]; struct node { int cnt[MAXN], l, r; ll ans; node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); } inline ll w(int ql, int qr) { if (ql > qr) return 0; for (; l > ql; ans += cnt[a[--l]]++); for (; r < qr; ans += cnt[a[++r]]++); for (; l < ql; ans -= --cnt[a[l++]]); for (; r > qr; ans -= --cnt[a[r--]]); return ans; } } A, B; ll dp[MAXN][30]; int p[MAXN][30]; inline void checkA(int i, int j, int k) { ll x = dp[j][k - 1] + A.w(j + 1, i); if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j; } inline void checkB(int i, int j, int k) { ll x = dp[j][k - 1] + B.w(j + 1, i); if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j; } void solve(int l, int r, int k) { if (r - l == 1) return ; int mid = l + r >> 1; for (int i = p[l][k]; i <= p[r][k]; i++) checkA(mid, i, k); solve(l, mid, k); for (int i = l + 1; i <= mid; i++) checkB(r, i, k); solve(mid, r, k); } int n, k; int main() { scanf("%d%d", &n, &k); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); memset(dp, 0x3f, sizeof dp), **dp = 0; for (int i = 1; i <= k; i++) checkA(n, 0, i), solve(0, n, i); printf("%lld", dp[n][k]); } ``` ::: ### 莫队使用例:[P9266 [PA 2022] Nawiasowe podziały](https://www.luogu.com.cn/problem/P9266) 设一个区间的代价为 $w(l,r)$,即区间 $[l,r]$ 内的合法括号子串个数。我们有 $w(l,r)+w(l+1,r-1)\ge w(l+1,r)+w(l,r-1)$。因此转移满足四边形不等式,同时可以用 wqs 二分去掉段数的限制。 但这时你发现个问题,这个代价很难用莫队以外的办法算出来。这导致二分队列直接倒闭了。同时,转移是在线的,整体二分还要套 cdq 才能处理。 本讲方法的优越性就体现出来了,可以直接做到 $\mathcal{O}(n\log n\log V)$,常数较小。 :::info[Code] ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAXN = 1e5 + 10; int n, m, pos[MAXN], col[MAXN]; char s[MAXN]; struct node { int cnt[MAXN], l, r; ll ans; node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); } inline ll w(int ql, int qr) { if (ql > qr) return 0; for (; l > ql; ) { l--; if (s[l] == '(' && pos[l] <= r) ans += ++cnt[col[l]]; } for (; r < qr; ) { r++; if (s[r] == ')' && pos[r] >= l) ans += ++cnt[col[r]]; } for (; l < ql; ) { if (s[l] == '(' && pos[l] <= r) ans -= cnt[col[l]]--; l++; } for (; r > qr; ) { if (s[r] == ')' && pos[r] >= l) ans -= cnt[col[r]]--; r--; } return ans; } } A, B; ll dp[MAXN], X; int cnt[MAXN], p[MAXN]; inline void checkA(int i, int j) { ll x = dp[j] + A.w(j + 1, i) + X; if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j; else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j; } inline void checkB(int i, int j) { ll x = dp[j] + B.w(j + 1, i) + X; if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j; else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j; } void solve(int l, int r) { if (r - l == 1) return ; int mid = l + r >> 1; for (int i = p[l]; i <= p[r]; i++) checkA(mid, i); solve(l, mid); for (int i = l + 1; i <= mid; i++) checkB(r, i); solve(mid, r); } inline bool check() { for (int i = 1; i <= n; i++) dp[i] = 1e18, cnt[i] = p[i] = 0; checkA(n, 0), solve(0, n); return cnt[n] <= m; } int st[MAXN], tp, id; ll l, r, ans; int main() { scanf("%d%d%s", &n, &m, s + 1); for (int i = 1; i <= n; i++) { if (s[i] == '(') { st[++tp] = i, pos[i] = n + 1; continue; } if (!tp) continue; int j = st[tp--]; pos[i] = j, pos[j] = i; col[i] = col[j] = (col[j - 1] ? col[j - 1] : ++id); } for (l = 0, r = 1e10; l <= r; ) { X = l + r >> 1; if (check()) r = X - 1, ans = X; else l = X + 1; } X = ans, check(); printf("%lld", dp[n] - X * m); } ``` :::