P12030 题解

· · 题解

题意略。

首先发现 John 一定选最大的 A 瓶,Nhoj 一定选最大的 B 瓶(因为它们增长量最大)。所以选定的 A 瓶是固定的,但 B 瓶会变的,难以维护,所以考虑换一种方式。

我们先将 A 瓶牛奶从大到小排序,加上 d,再对每个瓶子去减不超过 d 次,总共减 B \times d 次,使得值最小。

这个过程可以用线段树维护,具体来说,对于一段牛奶数相同的区间,有 \text{3} 种情况会使它停止减少:可以与下一段合并,B \times d 用完了,最左边的减少次数达到 d(因为它一定是该区间减的次数最多的)。

而一个数最多被合并一次,最多达到减少 d 次一次,B \times d 最多用光一次,所以整体时间复杂度为 O(n \log n)

upd:对于线段树上的操作,实际上可以用一个双端队列维护,对于每个位置记录时间戳,大于 d 就弹出。这样子再与下一段合并,只要记录最后退出时队列中的元素应是多少,这样瓶颈就在于排序了。

#include <bits/stdc++.h>
using namespace std;
const int mod=1e9+7,MAXN=1e5+10;
int n,A,B,a[MAXN],ans;
long long d;
struct Segment{
    struct tree{
        int val,add;
    }t[MAXN<<2];
    void build1(int l,int r,int rt)
    {
        t[rt].add=0;
        if(l==r)
        {
            t[rt].val=a[l];
            return;
        }
        int m=(l+r)>>1;
        build1(l,m,rt<<1);
        build1(m+1,r,rt<<1|1);
        t[rt].val=min(t[rt<<1].val,t[rt<<1|1].val);
    }
    void build2(int l,int r,int rt)
    {
        t[rt]={d,0};
        if(l==r) return;
        int m=(l+r)>>1;
        build2(l,m,rt<<1);
        build2(m+1,r,rt<<1|1);
    }
    void Pointupdate(int rt,int res)
    {
        t[rt].add+=res;
        t[rt].val-=res;
    }
    void pushdown(int rt)
    {
        if(t[rt].add)
        {
            Pointupdate(rt<<1,t[rt].add);
            Pointupdate(rt<<1|1,t[rt].add);
            t[rt].add=0;
        }
    }
    void update(int l,int r,int L,int R,int res,int rt)
    {
        if(L<=l && r<=R) return Pointupdate(rt,res);
        pushdown(rt);
        int m=(l+r)>>1;
        if(L<=m) update(l,m,L,R,res,rt<<1);
        if(m<R) update(m+1,r,L,R,res,rt<<1|1);
        t[rt].val=min(t[rt<<1].val,t[rt<<1|1].val);
    }
    int query(int l,int r,int L,int R,int rt)
    {
        if(L<=l && r<=R) return t[rt].val;
        pushdown(rt);
        int m=(l+r)>>1,ans=2e9;
        if(L<=m) ans=query(l,m,L,R,rt<<1);
        if(m<R) ans=min(ans,query(m+1,r,L,R,rt<<1|1));
        return ans;
    }
}T1,T2;

int main()
{
    cin>>n>>d>>A>>B;
    for(int i=1;i<=n;i++) cin>>a[i];
    sort(a+1,a+n+1,greater<int>());
    for(int i=A+1;i<=n;i++) ans=(ans+1ll*a[i]*a[i]%mod)%mod;
    n=A;
    for(int i=1;i<=n;i++) a[i]+=d;
    T1.build1(1,n,1),T2.build2(1,n,1);
    d=d*B;
    int r,l=1;
    for(int i=1;i<=n;i++)
    {
        if(a[i]==a[1]) r=i;
        else break;
    }
    while(d)
    {
        long long x=min(1ll*T1.query(1,n,l,r,1)-a[r+1],min(d/(r-l+1),1ll*T2.query(1,n,l,l,1)));
        d-=(r-l+1)*x;
        T2.update(1,n,l,r,x,1);
        T1.update(1,n,l,r,x,1);
        if(T1.query(1,n,l,r,1)!=a[r+1] && T2.query(1,n,l,l,1)>0)
        {
            if(d) T1.update(1,n,l,l+d-1,1,1);
            break;
        }
        else if(T1.query(1,n,l,r,1)!=a[r+1])
        {
            while(T2.query(1,n,l,l,1)==0 && l<=r) ++l;
            if(l>r)
            {
                r=l;
                for(int i=r+1;i<=n;i++)
                {
                    if(a[i]==a[r]) ++r;
                    else break;
                }
            }
        }
        else
        {
            r++;
            for(int i=r+1;i<=n;i++)
            {
                if(a[i]==a[r]) ++r;
                else break;
            }
        }
    }
    for(int i=1;i<=n;i++)
    {
        int x=T1.query(1,n,i,i,1);
        ans=(ans+1ll*x*x%mod)%mod;
    }
    cout<<ans<<endl;
    return 0;
}