题解:AT_abc383_g [ABC383G] Bar Cover

· · 题解

本题解证明比较粗略,欢迎 Hack 或补充。

提供一种复杂度与 k 无关的算法。

a 的前缀和为 s

首先考虑 dp,设 f_{i,j} 为前 i 个数中已经选了 j 个区间了,列出方程:

f_{i,j}=\max(f_{i-1,j},f_{i-k,j-1}+s_i-s_{i-k})

考虑优化这个东西。

根据数学直觉,发现其每个 f_i 都是上凸的。

这启发我们使用 slope trick(参考 CF1787H)。

套路的,设出差分数组 g_{i,j}=f_{i,j}-f_{i,j-1},那么 f_{i,j}=\sum_{l=1}^{j} g_{i,l}

代入原转移方程:

\begin{aligned} g_{i,j}&=\max(\sum_{l=1}^{j} g_{i-1,l},\sum_{l=1}^{j-1} g_{i-k,l}+s_i-s_{i-k})\\ &-\max(\sum_{l=1}^{j-1} g_{i-1,l},\sum_{l=1}^{j-2} g_{i-k,l}+s_i-s_{i-k}) \end{aligned}

然后就要计算这些 \max 取到哪个,因为根据数学直觉,f_{i-1} 肯定比 f_{i-k} 更凸,

所以猜测或打表,发现对于 \max(\sum_{l=1}^{j} g_{i-1,l},\sum_{l=1}^{j-1} g_{i-k,l}+s_i-s_{i-k}),存在一个分割点,使得所有的 p \le j 取到前者,对于所有的 p > j 取到后者。

所以对于 g_{i,j} 的转移有三种情况:

g_{i,j}= \left\{\begin{matrix} g_{i-1,j} & j \le p\\ \sum_{l=1}^{j-1} g_{i-k,l}+s_i-s_{i-k}-\sum_{l=1}^{j-1} g_{i-1,l} & j=p+1\\ g_{i-k,j-1} & j > p + 1 \end{matrix}\right.

然后应该可以通过归纳法,证明出 f_{i} 是凸的。

然后考虑如何维护这个,类似 CF1787H,不难想到使用持久化平衡树,通过二分 p 再在两颗平衡树上二分来找到 p,然后从 g_{i-1}g_{i-k} 的平衡树上裂出一部分拼到 g_i 即可。

时间复杂度 O(n \log^2 n)

使用了 Leafy Tree 实现。

const int N = 2e5 + 5;
const int M = 3e7 + 5;
const ll LNF = 1e18;
int n, k, m; 
ll a[N], s[N];
int rt[N], tot, ls[M], rs[M], sz[M]; ll F[M];
ll ans[N]; int cnt;
int add() {
    return ++ tot;
}
int add(ll x) {
    int u = add();
    sz[u] = 1;
    F[u] = x;
    return u;
}
void up(int u) {
    sz[u] = sz[ls[u]] + sz[rs[u]];
    F[u] = F[ls[u]] + F[rs[u]];
}
int up(int l, int r) {
    int u = add();
    ls[u] = l, rs[u] = r;
    up(u);
    return u;
}
int merge(int u, int v) {
    if(! u || ! v) return u | v;
    if(sz[u] <= sz[v] * 4 && sz[v] <= sz[u] * 4) {
        return up(u, v);
    }
    if(sz[u] >= sz[v]) {
        int l = ls[u], r = rs[u];
        if(sz[l] * 4 > (sz[u] + sz[v])) return merge(l, merge(r, v));
        return merge(merge(l, ls[r]), merge(rs[r], v));
    }
    else {
        int l = ls[v], r = rs[v];
        if(sz[r] * 4 > (sz[u] + sz[v])) return merge(merge(u, l), r);
        return merge(merge(u, ls[l]), merge(rs[l], r));
    }
}
void split(int u, int p, int & x, int & y) {
    if(! u || ! p) {
        x = 0, y = u;
        return;
    }
    if(sz[u] == p) {
        x = u, y = 0;
        return;
    }
    if(p <= sz[ls[u]]) {
        split(ls[u], p, x, y);
        y = merge(y, rs[u]);
    }
    else {
        split(rs[u], p - sz[ls[u]], x, y);
        x = merge(ls[u], x);
    }
}
ll query(int u, int r) {
    if(! u || ! r) return 0;
    if(sz[u] <= r) {
        return F[u];
    }
    if(r <= sz[ls[u]]) return query(ls[u], r);
    return F[ls[u]] + query(rs[u], r - sz[ls[u]]);
}
void print(int u) {
    if(! ls[u]) {
        ans[++ cnt] = F[u];
        return;
    }
    print(ls[u]);
    print(rs[u]);
}
void solve() {
    cin >> n >> k; m = n / k;
    FOR(i, 1, n) cin >> a[i];
    FOR(i, 1, n) s[i] = s[i - 1] + a[i];
    REP(i, k) rt[i] = add(- LNF);
    FOR(i, k, n) {
        int L = 1, R = i / k, pos = 0;
        while(L <= R) {
            int mid = L + R >> 1;
            ll vl = query(rt[i - 1], mid);
            ll vr = query(rt[i - k], mid - 1) + s[i] - s[i - k];
            if(vl >= vr) {
                pos = mid;
                L = mid + 1;
            }
            else {
                R = mid - 1;
            }
        }
        ll vl = query(rt[i - 1], pos + 1);
        ll vr = query(rt[i - k], pos) + s[i] - s[i - k];
        if(pos == i / k) {
            rt[i] = rt[i - 1];
        }
        else if(! pos && vl < vr) {
            rt[i] = merge(add(s[i] - s[i - k]), rt[i - k]);
        }
        else {
            int x, y, z;
            split(rt[i - 1], pos, x, z);
            split(rt[i - k], pos, z, y);
            ll val = query(rt[i - k], pos) - query(rt[i - 1], pos) + s[i] - s[i - k];
            rt[i] = merge(merge(x, add(val)), y);
        }
    }
    print(rt[n]);
    FOR(i, 1, n / k) ans[i] += ans[i - 1];
    FOR(i, 1, n / k) cout << ans[i] << " "; cout << endl;
}