CF1784C Monsters (hard version) 题解

· · 题解

分析

沿用和 easy version 一样的思路,即操作 1 会将原序列减小至连续递增状态再进行操作 2。

以下的操作均指操作 1。

从小到大考虑序列中的元素,可以发现对于前 i 个数,只会对 \min(i,a_i) 个数进行操作,并对答案产生贡献,对其他数字的操作是冗余的,据此可以写出 easy version 的代码如下:

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 200010;
int n;
int a[N];
void solve(){
    cin>>n;
    for (int i=1;i<=n;++i) cin>>a[i];
    sort(a+1,a+1+n);
    ll ans=0,len=0;
    for (int i=1;i<=n;++i){
        len++;
        if (a[i]>=len) ans+=a[i]-len;
        else len--;
    }
    cout<<ans<<endl;
}
int main(){
    int t;cin>>t;
    while(t--) solve();
    return 0;
}

回到 hard version,我们记录当前真正被进行操作的数字个数 len,并维护一个数组 d_i 表示小于等于 i 的数中,还可以被操作的数字个数,初始时 d_i=i

新加入一个数 x 时,我们假设强制选择 x 进行操作,那么 x 会让 i \in [x,n] 这一段区间的 d_i1

若此时 \forall i \in [1,n] ,d_i \geqslant 0,说明序列仍合法,那么将序列长度 len1 后,对答案有 x-len 的贡献。

否则,找到最小的使得 d_{pos}<0pos,一定有 x\leqslant pos。那么在不改变序列长度的情况下,留下 x 一定比留下 pos 优秀。因此我们把 i \in [pos,n] 这一段区间的 d_i1,表示撤销 pos 原本对 d_i 的贡献,这样操作对答案的贡献是 x-pos

可以发现 pos 上一定有之前被操作的数,否则设 y 为满足 y < pos 的第一个被操作的数,由于 y 的限制比 pos 更紧而选择的数字个数不变, 一定也满足 d_y < 0,那么 pos 就不是最小的数了。

上述过程涉及区间加和查询最小值,可以维护一颗线段树,查询 pos 时在线段树上二分。

每次加入一个数后修改答案即可。

时间复杂度 O(n \log n)

代码如下:

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 200010;
int n;
int a[N];
struct node{
    int mn,tag;
    #define ls p<<1
    #define rs p<<1|1
    #define mn(p) t[p].mn
    #define tag(p) t[p].tag
}t[N<<2];
void pushdown(int p){
    if (tag(p)!=0){
        mn(ls)+=tag(p); mn(rs)+=tag(p);
        tag(ls)+=tag(p); tag(rs)+=tag(p);
        tag(p)=0;
    }
}
void pushup(int p){
    mn(p)=min(mn(ls),mn(rs));
}
void build(int p,int l,int r){
    tag(p)=0;
    if (l==r) {
        mn(p)=l;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid); build(rs,mid+1,r);
    pushup(p);
}
void update(int p,int l,int r,int pl,int pr,int d){
    if (l<=pl && pr<=r){
        mn(p)+=d; tag(p)+=d;
        return;
    }
    pushdown(p);
    int mid=(pl+pr)>>1;
    if (mid>=l) update(ls,l,r,pl,mid,d);
    if (mid+1<=r) update(rs,l,r,mid+1,pr,d);
    pushup(p);
}
int find(int p,int l,int r){
    if (mn(p)>=0) return 0;
    if (l==r) return l;
    pushdown(p);
    int mid=(l+r)>>1;
    if (mn(ls)<0) return find(ls,l,mid);
    else return find(rs,mid+1,r);
}
void solve(){
    cin>>n;
    for (int i=1;i<=n;++i) cin>>a[i];
    ll ans=0,len=0;
    build(1,1,n);
    for (int i=1;i<=n;++i){
        update(1,a[i],n,1,n,-1);
        int pos=find(1,1,n);
        if (pos){//ai对之后的答案有更新
            ans+=a[i]-pos;
            update(1,pos,n,1,n,1);
        }
        else {
            len++;
            ans+=a[i]-len;
        }
        printf("%lld ",ans);
    }
    puts("");
}
int main(){
    int T;cin>>T;
    while(T--) solve();
    return 0;
}