P9732 [CEOI 2023] Trade

· · 题解

或许更好的阅读体验。

思路:

容易发现,这里 w(l, r) 的值是 A 区间前 k 大减去 B 区间的和;然后问的是所有 w(l, r) 的最大值,以及判断每个位置是否可能是某个最大的 w(l, r) 的前 k 大。

对于一个固定的右端点 r,设以 r 为右端点的最大权值是 f_r

于是问题变为求 f_r;考虑证明 w(l, r) 满足四边形不等式:

于是满足决策单调性,可以直接分治,算 w(l, r) 的时候直接主席树即可。

然后考虑第二问怎么做,考虑首先找出所有最优区间,但是这种可能很多怎么办?继续应用四边形不等式发现性质,若 a < b \le c < d[a, d][b, c] 是最优区间,根据四边形 w(a, c) + w(b, d) \ge w(b, c) + w(a, d) = 2 mx,于是 [a, c], [b, d] 必然也是最优区间。然后你注意到 [a, d] 的前 k 大在 [a, c][b, d] 中一定也是前 k 大;于是我们不需要处理包含关系。

于是可以走双指针求出右端点递增时左端点不降的所有最优区间,具体的,对于一个右端点 r 求出最大的左端点 opt_r 使得 [opt_r, r] 是最优区间;然后考虑 r 前面一个最优区间的左端点 l,那么只有左端点在 [l, opr_r] 这个区间且右端点是 r 的最优区间才可能对答案造成贡献。

找到了可以对答案造成贡献的区间,现在即我们需要支持给一个区间中前 k 大的数打标记,容易主席树找到区间第 kv;于是只需要支持给区间 [l, r]\ge v 的位置打标记即可。

于是直接对序列扫描线,对于每个位置 i,找到所有覆盖它的操作 (l, r, v) 中的 v 的最小值 v_{\min},如果 a_i \ge v_{\min},那么 i 必然被打了标记;使用 multiset 维护所有的 v 即可。

时间复杂度为 O(n \log^2 n)

完整代码:

#include<bits/stdc++.h>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define lowbit(x) x & (-x)
#define fi first
#define se second
#define popcnt(x) __builtin_popcount(x)
#define open(s1, s2) freopen(s1, "r", stdin), freopen(s2, "w", stdout);
using namespace std;
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const int N = 2.5e5 + 10;
const ll inf = 1e18;
inline ll read(){
    ll x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-')
            f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
inline void write(ll x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
struct Node{
    int lson, rson;
    int num;
    ll sum;
}X[N << 5];
ll ans = -inf;
int n, k, cnt, num;
int a[N], h[N], rt[N];
ll s[N], f[N], opt[N];
multiset<int> S;
vector<int> In[N], Del[N];
inline void update(int &k, int l, int r, int i){
    X[++cnt] = X[k];
    k = cnt;
    ++X[k].num;
    X[k].sum += h[i];
    if(l == i && i == r)
      return ;
    int mid = (l + r) >> 1;
    if(i <= mid)
      update(X[k].lson, l, mid, i);
    else 
      update(X[k].rson, mid + 1, r, i);
}
inline int kth(int k1, int k2, int l, int r, int k){
    if(l == r)
        return l;
    int mid = (l + r) >> 1;
    if(X[X[k1].rson].num - X[X[k2].rson].num >= k)
        return kth(X[k1].rson, X[k2].rson, mid + 1, r, k);
    else
        return kth(X[k1].lson, X[k2].lson, l, mid, k - (X[X[k1].rson].num - X[X[k2].rson].num));
}
inline ll kth_sum(int k1, int k2, int l, int r, int k){
    if(l == r)
        return 1ll * k * h[l];
    int mid = (l + r) >> 1;
    if(X[X[k1].rson].num - X[X[k2].rson].num >= k)
        return kth_sum(X[k1].rson, X[k2].rson, mid + 1, r, k);
    else
        return X[X[k1].rson].sum - X[X[k2].rson].sum + kth_sum(X[k1].lson, X[k2].lson, l, mid, k - (X[X[k1].rson].num - X[X[k2].rson].num));
}
inline ll getw(int l, int r){
    if(r - l + 1 < k)
      return -inf;
    return kth_sum(rt[r], rt[l - 1], 1, num, k) - (s[r] - s[l - 1]);
}
inline void solve(int l, int r, int kl, int kr){
    if(l > r || kl > kr)
      return ;
    int mid = (l + r) >> 1, now = 0;
    for(int i = min(mid, kr); i >= kl; --i){
        ll t = getw(i, mid);
        if(t > f[mid]){
            f[mid] = t;
            now = i;
        }
    }
    opt[mid] = now;
    solve(l, mid - 1, kl, now);
    solve(mid + 1, r, now, kr);
}
int main(){
    n = read(), k = read();
    for(int i = 1; i <= n; ++i)
      s[i] = s[i - 1] + read();
    for(int i = 1; i <= n; ++i)
      h[++num] = a[i] = read();
    sort(h + 1, h + num + 1);
    num = unique(h + 1, h + num + 1) - (h + 1);
    for(int i = 1; i <= n; ++i){
        rt[i] = rt[i - 1];
        a[i] = lower_bound(h + 1, h + num + 1, a[i]) - h;
        update(rt[i], 1, num, a[i]);
        f[i] = -inf;
    }
    solve(k, n, 1, n);
    for(int i = k; i <= n; ++i)
      ans = max(ans, f[i]);
    int lst = 1;
    for(int i = k; i <= n; ++i){
        if(f[i] != ans)
          continue;
//      cerr << i << ' ' << opt[i] << '\n';
        for(int l = lst; l <= opt[i]; ++l){
            if(getw(l, i) == ans){
                int v = kth(rt[i], rt[l - 1], 1, num, k);
//              cerr << l << ' ' << i << ' ' << v << '\n';
                In[l].push_back(v);
                Del[i].push_back(v);
            }
        }
        lst = opt[i];
    }
    write(ans);
    putchar('\n');
    for(int i = 1; i <= n; ++i){
        for(auto v : In[i])
          S.insert(v);
        if(S.empty())
          putchar('0');
        else
          putchar(a[i] >= (*S.begin()) ? '1' : '0');
        for(auto v : Del[i])
          S.erase(S.find(v));
    }
    return 0;
}