题解 CF1344D 【Résumé Review】

· · 题解

提供一个和官方题解不同的算法。

注意到如果把题目要求中的“整数”和b_i\le a_i这两个限制去掉,就是一个基本的求多元函数最大值的形式。所以可以先假设b_i可以取任何的非负实数,求出最优解,然后在其基础上调整出答案。

求可导的多元函数的最大值,一般都可以用拉格朗日乘数法来搞(维基百科)。对于这道题而言,可以列出这个方程组:

\frac{\mathrm{d}}{\mathrm{d}b_i}\left(\sum_{i=1}^{n}b_i\left(a_i-b_i^2\right)-\lambda\left(\sum_{i=1}^{n}b_i-k\right)\right)=0,i\in [1,n]

化简一下就是:

a_i-3b_i^2=\lambda,i\in [1,n]

发现因为b_i\ge 0,左边是单调的,这和标准解法利用的差分之后单调递减的性质差不多。

由于是单调的,可以二分\lambda,通过\lambda解出所有b_i,然后比较\sum{b_i}k的大小决定该往哪半边走。这样简化版的问题就做完了,时间复杂度O(n\log{(\max{a_i})})

然后考虑加回去b_i\le a_i的限制。把a_i\lt b_ia_i-3b_i^2=\lambda联立,可以解出只有a_i小于某个值时才会超出限制。所以把a数组从小到大排序,不满足限制的一定是开头的一部分数。

如果在前一步二分的时候,把二分的下界定为\max{(a_i-3a_i^2)},容易看出此时求得的方案一定满足b_i\le a_i的限制。然而这个二分出的方案不一定是合法解,因为如果最后求出的\lambda恰好等于下界时,此时可能\sum b_i<k

所以我们在外层再套一个二分,二分最优解中是最小的midb_i卡到了上限,只求解a[mid+1\dots n]这些数,通过内层二分计算出的\sum b_i是否<k-\sum_{i=1}^{mid}a_i来判断该往哪半边走。

现在求出的方案是满足限制的实数最优解。感性理解一下就知道整数的最优解ans_i\in \{\lfloor b_i\rfloor,\lceil b_i\rceil\}(这个性质我不会严谨证明)。所以我们一开始先赋值ans_i=\lfloor b_i\rfloor,计算出rest=k-\sum{ans_i},然后选择出restans_i加上1

这一步很简单,把每一个ans_i+1时对答案增加的贡献排个序,选择最大的rest个就可以了。要注意的是如果ans_i=a_i就不能再+1了。

最后把ans数组还原成输入时的顺序即可。时间复杂度O(n\log{(\max{a_i})}\log{n})

这就是我比赛现场的做法。这个做法得到了TLE on test 10。赛后发现调整几个精度问题就过了:

代码:

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
using ld=long double;
int n,_,pos[1<<17];ll k,a[1<<17],b[1<<17];
struct st{ll x;int y;}c[1<<17],q[1<<17];
int chk(int x){
    ld l=-1e20,r=1e20,ik=k;int ok=0;
    for(int i=1;i<=x;i++)ik-=a[i];
    if(ik<0)return 0;
    for(int i=x+1;i<=n;i++)l=max(l,(ld)a[i]-3*a[i]*a[i]),r=min(r,(ld)a[i]);
    while(r-l>1){
        ld mid=(l+r)/2,tot=0;
        for(int i=x+1;i<=n;i++)tot+=sqrtl((a[i]-mid)/3);
        if(tot>ik)l=mid,ok=1;else r=mid;
    }
    for(int i=x+1;i<=n;i++)b[i]=sqrtl((a[i]-l)/3);
    return !ok;
}
int main(){
    scanf("%d%lld",&n,&k);
    for(int i=1;i<=n;i++)scanf("%lld",a+i),c[i]={a[i],i};
    sort(c+1,c+1+n,[](st a,st b){return a.x<b.x;});
    for(int i=1;i<=n;i++)pos[c[i].y]=i;
    sort(a+1,a+1+n);
    int l=0,r=n+1,mid;
    while(l<r){
        mid=l+r>>1;
        if(chk(mid))l=mid+1;
        else r=mid;
    }
    for(int i=1;i<=l;i++)b[i]=a[i];
    chk(l);
    for(int i=1;i<=n;i++){k-=b[i];if(b[i]<a[i])q[++_]={a[i]-3*b[i]*b[i]-3*b[i],i};}
    sort(q+1,q+1+_,[](st a,st b){return a.x>b.x;});
    for(int i=1;i<=k;i++)b[q[i].y]++;
    for(int i=1;i<=n;i++)printf("%lld%c",b[pos[i]]," \n"[i==n]);
    return 0;
}