PA2021 Wystawa

· · 题解

题意:给定两个长度为 n 的序列 a,b,试构造一个同样长 n 的序列 c,满足有恰好 k 个位置,c_i\leftarrow a_i,其它位置 c_i\leftarrow b_i,请输出所有合法的 c 里最大子段和最小的一个,如果有多个,输出任意一个即可。

直接构造一个最小的不是很好做,但是可以构造一个不超过 x 的出来,即可以判定一个值是否合法,所以考虑二分。

对于 check,考虑这样一点:当 a_i<b_i 时,我们选 A 肯定是更好的,假设有 k^{'} 个点满足这个条件,那么我们只需要把这些位置填 A,然后替换 k^{'}-k 个点。

怎么替换呢?显然是替换代价最小的那些点。当出现前缀的最大子段和 g \leq 0 的时候考虑替换,注意:若换了以后最大子段和会变成一个正数,那么不应该进行这次替换,同时把这个 A 换掉的代价应该加上 g(原因显然)。在求完整个序列的最大子段和后,我们再考虑替换剩下的数,只要选出最小的一些加到 g 上就好了。

我们还需要维护一个数应不应该被替换,显然,如果一个位置应该被替换,则代价比它小的位置应该被替换;那么,维护另一个最大子段和 h,对应的选数策略是能选 B 就选 B,如果这个 h 超过了 x,那么此时替换代价最大的位置就不应该被换掉,让 h 减去它的代价,然后删掉它。

需要维护每个位置的代价,然后支持单点赋值(或插入删除),全局最值。随便搞一个数据结构都行。

当然,我们上面的假设基于 k^{'}\geq k,否则,考虑交换 AB,然后填 n-k 个新的 A 即可,注意 a_i=b_i 的处理,我们之前的做法是只在 a_i\lt b_i 时才选择 a_i,即相等时选择 b_i,交换后也应该保持这个方案。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 2e5 + 5;
struct node {
    int pos;
    ll val;
    node(int _pos = -1, ll _val = -2e14) {
        val = _val, pos = _pos;
    }
    bool operator < (const node &o) const {
        return (val ^ o.val) ? (val < o.val) : (pos < o.pos);
    }
};
int n, k, kk, a[N], b[N], swi;
char ans[N];
inline bool check(ll x) {
    ll g = 0, h = 0;
    int rest = kk - k;
    set<node> S;
    for (int i = 1; i <= n; i++) {
        if (a[i] < b[i] + swi)
            g += a[i], ans[i] = 'A', S.insert(node(i, b[i] - a[i]));
        else
            g += b[i], ans[i] = 'B';
        if (g > x)
            return false;
        while (rest && g <= 0ll && !S.empty()) {
            auto it = S.begin();
            if (g + it->val <= 0ll) 
                ans[it->pos] = 'B', g += it->val, rest--, S.erase(it);
            else {
                S.erase(it), S.insert(node(it->pos, g + it->val));
                break;
            }
        }
        g = max(g, 0ll), h = max(h + b[i], 0ll);
        while (h > x && !S.empty()) {
            auto it = --S.end();
            h -= it->val, S.erase(it);
        }
    }
    while (rest && !S.empty()) {
        auto it = S.begin();
        ans[it->pos] = 'B', g += it->val, rest--, S.erase(it);
    }
    return rest == 0 && g <= x;
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> n >> k;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= n; i++)
        cin >> b[i];
    for (int i = 1; i <= n; i++)
        if (a[i] < b[i])
            kk++;
    swi = 0;
    ll l = 0, r = 2e14, ans_val = -1;
    if (kk < k) {
        kk = n - kk, k = n - k, swi = 1;
        for (int i = 1; i <= n; i++)
            swap(a[i], b[i]);
    }
    while (l <= r) {
        ll mid = (l + r) >> 1;
        if (check(mid))
            ans_val = mid, r = mid - 1;
        else
            l = mid + 1;
    }
    assert(check(ans_val));
    cout << ans_val << '\n';
    for (int i = 1; i <= n; i++)
        ans[i] = (!swi ? "AB" : "BA")[ans[i] - 'A'];
    for (int i = 1; i <= n; i++)
        cout << ans[i];
    cout << endl;
    return 0;
}