P9338 [JOISC 2023 Day3] Chorus

· · 题解

感觉完全做不来这种题/ll.

首先发现 =k\leq k 其实是等价的,先考虑对于一个确定的串,怎样划分能够使得子序列的数量尽量少。这比较简单,只需要每次贪心的往后取就行了。然后我们考虑一下大概会怎么交换,好像每次是给一个区间做排序状物,但似乎没有什么好的刻画方法,于是突然就做不下去了。

考虑一个神秘观察:把 S 看成在 n \times n 的网格中的一个路径,其中 A 表示向右一步,B 表示向上一步,显然一次操作只会是把一个上右改成右上。当然有解的一个必要条件是它必须在对角线下方,我们把在对角线上方的部分改到对角线下方,代价提前计算掉,现在只考虑路径在对角线之下的情况。

考虑在路径上做上面那个贪心,相当于每次沿着路径往右走,碰到一个向上就一直向上直到碰到对角线再向右,两次拐弯衔接的地方可能会平移一部分路径,但这个是合法的。

于是要求 k 个子序列,相当于是拐弯了 k 次,每个拐弯对应一个区间。现在假设已经确定了最后的这 k 个拐弯长成什么样,我们发现达到这个状态的代价恰好就是夹在钦定的路径上方,原路径下方的格子数量。具体可以看下面这个从官方题解偷来的图。

w(l,r) 为区间 [l,r] 对应拐弯,即从 (l,l) 走到 (r,r) 的代价,假设原点坐标是 (1,1)。那么现在问题相当于是要划分成 k 个区间,最小化代价和。老套路了,容易验证它满足四边形不等式,所以整个问题是凸的。考虑 WQS 二分,之后用利用决策单调性分治 / SMAWK / LARSCH 等等优化转移技巧可以做到 \mathcal{O}(n \log^3 n)\mathcal{O}(n \log n) 不等的时间复杂度。但它们要么复杂度太高,要么代码太难写,都不是我们想要的。

让我们暂时忘掉决策单调性。考虑给 w(l,r) 一个更好的表征:设 t_i 表示第 i 向上前向右的次数,pre_i 表示 t_i 的前缀和,cnt_isum_i 分别表示满足 t_j \leq ij 的个数和 t_j 的和,这些都容易 \mathcal{O}(n) 预处理。那么 w(l,r) = (cnt_{r-1} - l + 1) \times (r - 1) - sum_{r-1} + pre_{l-1}。拆开用斜率优化即可。

总时间复杂度 \mathcal{O}(n)。有时候知道得太多也不一定是一件好事。

#include <bits/stdc++.h>
using namespace std;
typedef pair <int, int> pi;
typedef long long LL;
constexpr int N = 2e6 + 5;
constexpr LL inf = 1e18;
int n, k, t[N], g[N]; LL cnt[N], sum[N], pre[N], f[N], ans, ns;
int q[N], hd, tl;
LL X(int i) { return i; }
LL Y(int i) {
    return f[i] + pre[i - 1] + i;
}
bool chk(LL m) {
    fill(f + 1, f + n + 2, inf);
    fill(g + 1, g + n + 2, k + 1);
    f[1] = g[1] = 0;
    q[hd = tl = 1] = 1;
    for (int i = 2; i <= n + 1; i++) {
        while (hd < tl && i * (X(q[hd + 1]) - X(q[hd])) > (Y(q[hd + 1]) - Y(q[hd]))) hd++;
        LL val = -X(q[hd]) * i + Y(q[hd]);
        f[i] = val + (cnt[i - 1] + 1) * (i - 1) - sum[i - 1] - m, g[i] = g[q[hd]] + 1;
        while (hd < tl && (Y(i) - Y(q[tl])) * (X(q[tl]) - X(q[tl - 1])) < (Y(q[tl]) - Y(q[tl - 1])) * (X(i) - X(q[tl]))) tl--;
        q[++tl] = i;
    } 
    ans = f[n + 1];
    return g[n + 1] <= k;
}
signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> k;
    string str;
    cin >> str;
    str = ' ' + str;
    int cA = 0, cB = 0;
    for (int i = 1; i <= 2 * n; i++) {
        if (str[i] == 'A') ++cA;
        else ++cB, t[cB] = cA;
    }
    for (int i = 1; i <= n; i++) if (t[i] < i) ns += i - t[i], t[i] = i;
    for (int i = 1; i <= n; i++) pre[i] = pre[i - 1] + t[i];
    for (int i = 1; i <= n; i++) ++cnt[t[i]], sum[t[i]] += t[i];
    for (int i = 1; i <= n; i++) cnt[i] += cnt[i - 1], sum[i] += sum[i - 1];
    LL l = -5e11, r = 0, p = 0;
    while (l <= r) {
        LL m = (l + r) >> 1;
        if (chk(m)) p = m, l = m + 1;
        else r = m - 1;
    }
    chk(p);
    cout << ns + ans + 1LL * p * k << "\n";
    return 0;
}