P10093 [ROIR 2022 Day 2] 礼物 题解

· · 题解

这题应该有紫了吧。

考虑钦定第 k 大为 x,定义序列 b_i=[a_i\ge x],则所有合法区间 [l,r] 的条件为 \sum_{i=l}^rb_i=k,把全 0 的段缩一下,则总共合法的区间数量只有 nk。则转化为,有 nk 个子问题,形如

l \in[x_i,y_i],r \in [p_i,q_i],\max(\sum_{i=l}^{r}a_i-\sum_{i=y_i+1}^{p_i-1}a_ib_i)

稍微画一下,就是要求一段区间的前缀后缀最大和。这个可以用 ST 表做到 O(n\log_2n)-O(1),然后确定这些区间段,可以用 set 和链表维护。

时间复杂度 O(n\log_2n+nk),空间 O(n\log_2n)

没人放代码啊,那为放一下:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
#define int long long
using namespace std;

int n,k;
int a[200005];
int b[200005];
int c[200005];
int s[200005];
int lf[200005];
int rt[200005];
int lg[200005];
set <int> st;

struct ST{
    int a[200005];
    int s[200005];
    int ms[21][200005];
    inline void init(){
        for(int i=1;i<=n;i++) s[i]=s[i-1]+a[i],ms[0][i]=a[i];
        for(int j=1;j<=lg[n];j++)
            for(int i=1;i+(1<<j)-1<=n;i++){
                int l=i,r=i+(1<<(j-1))-1;
                ms[j][i]=max(ms[j-1][i],s[r]-s[l-1]+ms[j-1][i+(1<<(j-1))]);
            }
        return ;
    }
    inline int ask(int l,int r){
        if(l>r) return 0;
        int len=lg[r-l+1];
        return max(ms[len][l],ms[len][r-(1<<len)+1]+s[r-(1<<len)]-s[l-1]);
    }
}t1,t2;

inline void in(int &n){
    n=0;
    char c=getchar();bool ok=c=='-';
    while(c<'0' || c>'9') c=getchar(),ok|=c=='-';
    while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
    n=(ok?-n:n);
    return ;
}

inline void relink(int pos){
    auto it=st.insert(pos).first,ti=it;
    ti--;
    it++;
    int x=*ti,y=*it;
    lf[pos]=x,rt[pos]=y;
    rt[x]=pos,lf[y]=pos;
    return ;
}

signed main(){
    in(n),in(k);
    for(int i=1;i<=n;i++){
        in(a[b[i]=i]);
        s[i]=s[i-1]+a[i];
        t1.a[i]=t2.a[n-i+1]=a[i];
        c[i]=a[i];
    }
    if(k==0){
        int s=0,ans=-1e18;
        for(int i=1;i<=n;i++){
            s=max(0ll,s)+a[i];
            ans=max(ans,s);
        }
        printf("%lld\n",ans);
        return 0;
    }
    rt[0]=n+1;
    lf[n+1]=0;
    rt[n+1]=n+1;
    st.insert(0);
    st.insert(n+1);
    for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
    t1.init(),t2.init();
    sort(b+1,b+1+n,[](int x,int y){return a[x]>a[y];});
    int ans=-1e18;
    for(int j=1;j<=n;j++){
        int i=b[j],cc=0,sum=a[i],u=i;
        relink(i);
        c[++cc]=i,sum=a[i];
        while(rt[u]<=n&&cc<k) c[++cc]=rt[u],sum+=a[rt[u]],u=rt[u];
        u=i;
        int ccc=cc;
        while(cc<k&&lf[u]) sum+=a[lf[u]],u=lf[u],cc++;
        if(cc<k) continue;
        cc=ccc;
        while(u&&cc){
            int v=c[cc--];
            int ss=s[v]-s[u-1]-sum;
            ss+=max(0ll,t2.ask(n-u+2,n-lf[u]));
            ss+=max(0ll,t1.ask(v+1,rt[v]-1));
            ans=max(ans,ss);
            sum-=a[v];
            u=lf[u];
            sum+=a[u];
        }
    }
    printf("%lld\n",ans);

    return 0;
}