[Algo Beat Contest 002 E]题解

· · 题解

首先易得一个结论:所有操作都在同一个数组上进行,一定可以得到最优解之一。证明如下:

假设有两个数组,两次操作。对第一个数组第一次操作对答案的贡献为 x_1,第二次操作贡献为 x_2。类似地,第二个数组分别对应 y_1,y_2。易得 x_1 \le x_2,y_1 \le y_2,所以 x_1+y_1 \le \max(x_1+x_2,y_1+y_2)。因此,两个数组各操作一次,一定不会是唯一最优解。把这个结论扩展到更多个数组上,同样也是成立的。

因此现在问题变为:一次操作,将某一个数组中的所有数增加 QK,求最大答案。

K 为正数时,我们可以用 multiset 存放前 M 大的数。枚举每一个数组,如果 A_{i,j} 增加 QK 后超过了第 M 大数,则在 multiset 中用 A_{i,j}+QK 替换该数,并更新答案。操作过程中需要记录这个数组多加入或删除的数,方便操作完后复原,并避免每次 O(M) 复原而超时。最后取操作每个数组得到的答案的最大值即可。

K 为负数时,与前面不同的是,这次 A_{i,j}+QK 比原来小。类似地,我们用 multiset 存放前 N-M 小的数,操作时不断更新前 N-M 小的数。注意当 SL=M 时需要特判,否则会因对空 multiset 取 begin() 而 RE/TLE。

SL=\sum_{i=1}^n L_i,则总时间复杂度为 O(SL \log SL)。最后就是一些细节问题,详见代码。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5;
int n,m,q,k,b[N],l[N],tot,cur,ans=-1e18,sum;
vector<int> a[N],era,ins;
multiset<int> S;
void work1(){
    sort(b+1,b+tot+1,greater<int>());
    for(int i=1;i<=m;i++)
        S.insert(b[i]),cur+=b[i];
    int ncur=cur;
    for(int i=1;i<=n;i++){
        era.clear();
        ins.clear();
        for(int j=1;j<=l[i];j++)
            if(a[i][j]+q*k>*S.begin()){
                int tmp=max(*S.begin(),a[i][j]);
                cur=cur+a[i][j]+q*k-tmp;
                S.erase(S.find(tmp));
                era.push_back(tmp);
                S.insert(a[i][j]+q*k);
                ins.push_back(a[i][j]+q*k);
            }
        ans=max(ans,cur);
        cur=ncur;
        for(auto v:era) S.insert(v);
        for(auto v:ins) S.erase(S.find(v));
    }
    cout<<ans;
}
void work2(){
    sort(b+1,b+tot+1);
    for(int i=1;i<=tot-m;i++)
        S.insert(-b[i]),cur+=b[i];
    int ncur=cur;
    for(int i=1;i<=n;i++){
        if(m==tot){
            ans=max(ans,sum+l[i]*q*k);
            continue;
        }
        era.clear();
        ins.clear();
        for(int j=1;j<=l[i];j++)
            if(a[i][j]+q*k<-*S.begin()){
                int tmp=min(-*S.begin(),a[i][j]);
                cur=cur+a[i][j]+q*k-tmp;
                S.erase(S.find(-tmp));
                era.push_back(-tmp);
                S.insert(-a[i][j]-q*k);
                ins.push_back(-a[i][j]-q*k);
            }
        ans=max(ans,sum+l[i]*q*k-cur);
        cur=ncur;
        for(auto v:era) S.insert(v);
        for(auto v:ins) S.erase(S.find(v));
    }
    cout<<ans;
}
signed main(){
    cin>>n>>m>>q>>k;
    for(int i=1;i<=n;i++){
        a[i].push_back(0);
        scanf("%lld",&l[i]);
        for(int j=1,x;j<=l[i];j++){
            scanf("%lld",&x);
            a[i].push_back(x);
            b[++tot]=x;
            sum+=x;
        }
    }
    if(k>0) work1();
    else work2();
    return 0;
}