P10240 题解

· · 题解

Problem Link

题目大意

给定 n 个元素,每个元素有权重 a_i,进行若干轮操作,每次选出最多的元素使得 \sum a_i\le m,多种方案选字典序最大一组方案,选出的元素都删除,求多少轮后所有元素被删空。

数据范围:n\le 50000

思路分析

首先考虑怎么选出最多元素,显然会按从小到大的顺序贪心取出前 k 个元素。

然后考虑怎么确定一组解,可以逐位贪心,即先最大化标号最小元素的位置,可以二分一个 x,那么我们就要求 [x,n] 范围内前 k 小元素和 \le m

由于我们要动态删除元素,因此可以树状数组套值域线段树树,求出一组解的复杂度 \mathcal O(k\log^3n)

由于 \sum k=n,因此总复杂度 \mathcal O(n\log ^3n)

从小到大贪心求 k 可以直接 std::multiset 维护。

时间复杂度:\mathcal O(n\log^3n)

代码呈现

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=50005;
const ll inf=1e18;
int n,m,a[MAXN],id[MAXN],rk[MAXN],vals[MAXN];
struct Segt {
    static const int MAXS=MAXN*200;
    int ls[MAXS],rs[MAXS],siz[MAXS],tot;
    ll sum[MAXS];
    void ins(int u,int op,int l,int r,int &p) {
        if(!p) p=++tot;
        siz[p]+=op,sum[p]+=op*vals[u];
        if(l==r) return ;
        int mid=(l+r)>>1;
        if(u<=mid) ins(u,op,l,mid,ls[p]);
        else ins(u,op,mid+1,r,rs[p]);
    }
    ll qry(int k,int l,int r,vector<int>&P) {
        if(l==r) return vals[l];
        int mid=(l+r)>>1,c=0;
        for(int p:P) c+=siz[ls[p]];
        if(k<=c) {
            for(int&p:P) p=ls[p];
            return qry(k,l,mid,P);
        } else {
            ll s=0;
            for(int&p:P) s+=sum[ls[p]],p=rs[p];
            return qry(k-c,mid+1,r,P)+s;
        }
    }
    int rt[MAXN];
    void ins(int x,int u,int op) { for(;x;x&=x-1) ins(u,op,1,n,rt[x]); }
    ll qry(int k,int x) {
        int s=0; vector <int> P;
        for(;x<=n;x+=x&-x) s+=siz[rt[x]],P.push_back(rt[x]);
        if(s<k) return inf;
        return qry(k,1,n,P);
    }
}   T;
multiset <int> A;
int solve() {
    int s=0,c=0;
    for(int i:A) {
        if(s+i>m) return c;
        s+=i,++c;
    }
    return c;
}
signed main() {
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i) scanf("%d",&a[i]),id[i]=i;
    sort(id+1,id+n+1,[&](int x,int y){ return a[x]<a[y]; });
    for(int i=1;i<=n;++i) rk[id[i]]=i,vals[i]=a[id[i]];
    for(int i=1;i<=n;++i) A.insert(a[i]),T.ins(i,rk[i],1);
    int cnt=0;
    for(;A.size();++cnt) {
        int k=solve(),rem=m;
        for(int i=1;i<=k;++i) {
            int l=1,r=n,p=0;
            while(l<=r) {
                int mid=(l+r)>>1;
                ll z=T.qry(k-i+1,mid);
                if(z<=rem) p=mid,l=mid+1;
                else r=mid-1;
            }
            A.erase(A.find(a[p])),T.ins(p,rk[p],-1),rem-=a[p];
        }
    }
    printf("%d\n",cnt);
    return 0;
}