题解:P10673 【MX-S1-T2】催化剂

· · 题解

因为要使每个小朋友拿到的糖种类尽量多,所以我们可以让一个小朋友取所有有剩余的糖各一颗,然后剩下的糖的总数就是答案。

那么对于第 i 种糖,答案就是 f_i-k(f_i\geq k),其中 f_i 表示目前还有多少颗第 i 种糖,但是我们会发现这个方向很难继续优化。

于是我们转换思路,记录 s_i 表示有 i 颗糖的糖果种类的数量,那么答案就是 (i-k)\times s_i(i>k),同时注意到 \sum_{i=x}^y i\times s_i\sum_{i=x}^y s_i 都可以用线段树在 O(\log n) 时间内得出。

故对于前两个操作就是线段树单点修改,最后一个操作就是线段树区间查询,复杂度是 O((n+q)\log n)

值得注意的是,正常数组需要开 2\times 10^6,因为有可能会出现非常多的一号操作。线段树记得开四倍。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const long long N=2e6;
int n,q,k,a[N+10],f[N+10],s[N+10],tree[4*N+10],t[4*N+10];
int read(){
    int x=0;char ac=getchar();
    while(ac<'0' || ac>'9') ac=getchar();
    while(ac>='0' && ac<='9') x=x*10+ac-'0',ac=getchar();
    return x;
}
void push_up(int id){
    tree[id]=tree[id*2]+tree[id*2+1];
    t[id]=t[id*2]+t[id*2+1];
}
void build_tree(int id,int l,int r){
    if(l==r){
        tree[id]=s[l]*l,t[id]=s[l];
        return ;
    }
    int mid=(l+r)/2;
    build_tree(id*2,l,mid);
    build_tree(id*2+1,mid+1,r);
    push_up(id);
}
void change(int id,int l,int r,int x){
    if(l==r){
        tree[id]=s[l]*l,t[id]=s[l];
        return ;
    }
    int mid=(l+r)/2;
    if(x<=mid) change(id*2,l,mid,x);
    else change(id*2+1,mid+1,r,x);
    push_up(id);
}
int cal(int id,int l,int r,int x,int y){
    if(x<=l && r<=y) return tree[id]-k*t[id];
    int mid=(l+r)/2,ans=0;
    if(x<=mid) ans+=cal(id*2,l,mid,x,y);
    if(y>mid) ans+=cal(id*2+1,mid+1,r,x,y);
    return ans;
}
signed main(){
    n=read(),q=read();
    for(int i=1;i<=n;i++)
        a[i]=read(),f[a[i]]++;
    for(int i=1;i<=n;i++)
        s[f[i]]++;
    //for(int i=1;i<=n;i++)
    //    printf("%d ",s[i]);
    build_tree(1,1,N);
    while(q--){
        int op=read();
        if(op==1){
            int x=read();
            s[f[x]]--,change(1,1,N,f[x]),f[x]++,s[f[x]]++,change(1,1,N,f[x]);
        }
        else if(op==2){
            int x=read();
            s[f[x]]--,change(1,1,N,f[x]),f[x]--,s[f[x]]++,change(1,1,N,f[x]);
        }
        else if(op==3){
            k=read();
            //printf("%d %d %d ",tree[1],t[1],cal(1,1,n,1,k));
            printf("%lld\n",cal(1,1,N,k+1,N));
        }
    }
    return 0;
}