P7856 「EZEC-9」模糊众数

· · 题解

对于一组询问,我们记 v 为询问希望得到的众数,x 为最后众数出现的次数。这个问题可以用费用流来描述,对坐标离散化并预留足够的溢出位置后,相邻两个点连一条边,v 这个位置向汇点连一条带上下界的边,其他地方连一条带上界的边。于是我们可以得到一个结论,令 f(x)=y 表示答案,则 f 是一个下凸函数。当然不用费用流也可以得到类似的结论。

现在解决给定 v,x 求答案。根据费用流的建模,要让 x 个数变成 v,一定可以看作是把最近的 x 个数变成 v。所以我们只需要解决对于序列的每个前缀和后缀,需要多少次操作才能使得众数的出现次数不超过 x

考虑根号分治。令块长为 B,对于 x\le B 的部分,每个 xO(n) 预处理。对于 x>B 的部分,显然不合法的位置至多有 \frac nB 个,总变化量是 O(\frac nB) 级别的。

我们枚举每个 x\le B 并更新答案。剩下的部分使用二分,二分到最后一个斜率为负的位置,就是最小值。总复杂度 O(Bn+\frac nB\log n),取 B=\sqrt {n\log n} 即可做到 O(n\sqrt{n\log n})

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int MAXN=1e5+5;
int n,q,a[MAXN],B,op[MAXN][2];
int all[MAXN],cnt[MAXN],_cnt[MAXN];
ll sum[MAXN];
ll f1[MAXN],g1[MAXN],f2[MAXN],g2[MAXN];
void Solve(int x,ll *f,ll *g){
    int l=0,c=0;
    ll res=0;
    for(int i=1; i<=n; i++){
        while(l<a[i]){
            if(c<=x){
                c=0;
                break;
            }
            c-=x;
            res+=c;
            l++;
        }
        l=a[i];
        c++;
        ll k=c/x;
        f[i]=res+c*k-x*k*(k+1)/2;
    }
    static int sx[MAXN],sy[MAXN],top;
    top=0;
    res=0;
    for(int i=n,j=n+1; i; i--){
        while(j>n||all[j]!=a[i]){
            j--;
            top++;
            sx[top]=all[j];
            sy[top]=0;
        }
        res+=sx[top]-a[i];
        if(++sy[top]==x) top--;
        g[i]=res;
    }
    return ;
}
vector<int> vec;
int chg[MAXN];
ll Calc1(int w,int x){
    if(!w) return 0;
    int _w=lower_bound(a+1,a+n+1,a[w])-a;
    _w=w-_w+1;
    w=lower_bound(all+1,all+n+1,a[w])-all;
    chg[*chg=1]=w;
    cnt[w]=_w;
    ll res=0;
    for(int i=0; i<vec.size(); i++){
        int k=vec[i];
        if(k>=w) break;
        if(cnt[k]>x) chg[++*chg]=k;
        while(k<w&&cnt[k]>x){
            chg[++*chg]=k+1;
            cnt[k+1]+=cnt[k]-x;
            res+=cnt[k]-x;
            cnt[k]=x;
            k++;
        }
    }
    while(cnt[w]>x){
        res+=cnt[w]-x;
        cnt[w]-=x;
    }
    for(int i=1; i<=*chg; i++)
        cnt[chg[i]]=_cnt[chg[i]];
    return res;
}
ll Calc2(int w,int x){
    if(w>n) return 0;
    int _w=upper_bound(a+1,a+n+1,a[w])-a-1;
    _w=_w-w+1;
    w=lower_bound(all+1,all+n+1,a[w])-all;
    chg[*chg=1]=w;
    cnt[w]=_w;
    ll res=0;
    for(int i=lower_bound(vec.begin(),vec.end(),w)-vec.begin(); i<vec.size(); i++){
        int k=vec[i];
        if(cnt[k]>x) chg[++*chg]=k;
        while(cnt[k]>x){
            chg[++*chg]=k+1;
            cnt[k+1]+=cnt[k]-x;
            res+=cnt[k]-x;
            cnt[k]=x;
            k++;
        }
    }
    for(int i=1; i<=*chg; i++)
        cnt[chg[i]]=_cnt[chg[i]];
    return res;
}
bool vis[MAXN];
ll ans[MAXN];
int main(){
    scanf("%d%d",&n,&q);
    B=min(800,n);
    for(int i=1; i<=n; i++)
        scanf("%d",a+i);
    if(n==1){
        while(q--){
            int x;
            scanf("%d",&x);
            if(x>=a[1]) printf("%d\n",x-a[1]);
            else puts("-1");
        }
        return 0;
    }
    sort(a+1,a+n+1);
    for(int i=1; i<=n; i++)
        all[i]=max(all[i-1]+1,a[i]),sum[i]=sum[i-1]+a[i];
    for(int i=1,j=1; i<=n; i++){
        while(all[j]!=a[i]) j++;
        cnt[j]++;
    }
    for(int i=1; i<=n; i++)
        if(cnt[i]>B) vec.push_back(i);
    memcpy(_cnt,cnt,sizeof(cnt));
    Solve(B,f1,g1);
    Solve(B-1,f2,g2);
    for(int i=1; i<=q; i++){
        scanf("%d",op[i]);
        int v=op[i][0];
        int w=upper_bound(a+1,a+n+1,v)-a-1,x;
        op[i][1]=w;
        if(w>=B){
            x=B;
            ll y1=f1[w-x]+g1[w+1]+(1ll*v*x-(sum[w]-sum[w-x]));
            x--;
            ll y2=f2[w-x]+g2[w+1]+(1ll*v*x-(sum[w]-sum[w-x]));
            if(y1>=y2) vis[i]=1,ans[i]=-1;
            else{
                int l=B,r=w;
                while(l<r){
                    int mid=l+r>>1;
                    x=mid;
                    y1=Calc1(w-x,x)+Calc2(w+1,x)+(1ll*v*x-(sum[w]-sum[w-x]));
                    x++;
                    y2=Calc1(w-x,x)+Calc2(w+1,x)+(1ll*v*x-(sum[w]-sum[w-x]));
                    if(y1<y2) r=mid;
                    else l=mid+1;
                }
                x=l;
                ans[i]=Calc1(w-x,x)+Calc2(w+1,x)+(1ll*v*x-(sum[w]-sum[w-x]));
            }
        }else vis[i]=1,ans[i]=-1;
    }
    for(int x=1; x<=B; x++){
        Solve(x,f1,g1);
        bool ok=0;
        for(int i=1; i<=q; i++)
            if(vis[i]){
                int v=op[i][0],w=op[i][1];
                if(x>w){
                    vis[i]=0;
                    continue;
                }
                ll y=f1[w-x]+g1[w+1]+(1ll*v*x-(sum[w]-sum[w-x]));
                if(x>1&&y>=ans[i]){
                    vis[i]=0;
                    continue;
                }
                ans[i]=y;
                ok=1;
            }
        if(!ok) break;
    }
    for(int i=1; i<=q; i++)
        printf("%lld\n",ans[i]);
    return 0;
}

之前写挂的原因:复制了一行代码忘改一点东西,两个函数对着写,结果第二个函数写反了。改进行修改的没修改。q 写成 n。总结:千万不要复制代码,总是忘改东西,也不要对着之前的代码写。