题解 CF1344D 【Résumé Review】

cosmicAC

2020-05-08 20:56:32

Solution

提供一个和官方题解不同的算法。 注意到如果把题目要求中的“整数”和$b_i\le a_i$这两个限制去掉,就是一个基本的求多元函数最大值的形式。所以可以先假设$b_i$可以取任何的非负实数,求出最优解,然后在其基础上调整出答案。 求可导的多元函数的最大值,一般都可以用[拉格朗日乘数法](https://baike.baidu.com/item/%E6%8B%89%E6%A0%BC%E6%9C%97%E6%97%A5%E4%B9%98%E6%95%B0%E6%B3%95)来搞([维基百科](https://zh.wikipedia.org/wiki/%E6%8B%89%E6%A0%BC%E6%9C%97%E6%97%A5%E4%B9%98%E6%95%B0))。对于这道题而言,可以列出这个方程组: $$\frac{\mathrm{d}}{\mathrm{d}b_i}\left(\sum_{i=1}^{n}b_i\left(a_i-b_i^2\right)-\lambda\left(\sum_{i=1}^{n}b_i-k\right)\right)=0,i\in [1,n]$$ 化简一下就是: $$a_i-3b_i^2=\lambda,i\in [1,n]$$ 发现因为$b_i\ge 0$,左边是单调的,这和标准解法利用的差分之后单调递减的性质差不多。 由于是单调的,可以二分$\lambda$,通过$\lambda$解出所有$b_i$,然后比较$\sum{b_i}$和$k$的大小决定该往哪半边走。这样简化版的问题就做完了,时间复杂度$O(n\log{(\max{a_i})})$。 然后考虑加回去$b_i\le a_i$的限制。把$a_i\lt b_i$和$a_i-3b_i^2=\lambda$联立,可以解出只有$a_i$小于某个值时才会超出限制。所以把$a$数组从小到大排序,不满足限制的一定是开头的一部分数。 如果在前一步二分的时候,把二分的下界定为$\max{(a_i-3a_i^2)}$,容易看出此时求得的方案一定满足$b_i\le a_i$的限制。然而这个二分出的方案**不一定是**合法解,因为如果最后求出的$\lambda$恰好等于下界时,此时可能$\sum b_i<k$。 所以我们在外层再套一个二分,二分最优解中是最小的$mid$个$b_i$卡到了上限,只求解$a[mid+1\dots n]$这些数,通过内层二分计算出的$\sum b_i$是否$<k-\sum_{i=1}^{mid}a_i$来判断该往哪半边走。 现在求出的方案是满足限制的实数最优解。感性理解一下就知道整数的最优解$ans_i\in \{\lfloor b_i\rfloor,\lceil b_i\rceil\}$(这个性质我不会严谨证明)。所以我们一开始先赋值$ans_i=\lfloor b_i\rfloor$,计算出$rest=k-\sum{ans_i}$,然后选择出$rest$个$ans_i$加上$1$。 这一步很简单,把每一个$ans_i+1$时对答案增加的贡献排个序,选择最大的$rest$个就可以了。要注意的是如果$ans_i=a_i$就不能再$+1$了。 最后把$ans$数组还原成输入时的顺序即可。时间复杂度$O(n\log{(\max{a_i})}\log{n})$。 这就是我比赛现场的做法。这个做法得到了TLE on test 10。赛后发现调整几个精度问题就过了: - 内层二分的结束条件是$r-l\le\epsilon$,这个$\epsilon=1$。因为二分边界最大可以达到$3\times10^{18}$,此时`long double`的精度只到个位,如果$\epsilon$比精度还要小就会死循环。 - 最后一步按照增加的贡献排序时,不能用`long double`算出来两个贡献再相减,要手动化简一下式子。因为两个相近的大数相减会大大降低精度。 代码: ```cpp #include<bits/stdc++.h> using namespace std; using ll=long long; using ld=long double; int n,_,pos[1<<17];ll k,a[1<<17],b[1<<17]; struct st{ll x;int y;}c[1<<17],q[1<<17]; int chk(int x){ ld l=-1e20,r=1e20,ik=k;int ok=0; for(int i=1;i<=x;i++)ik-=a[i]; if(ik<0)return 0; for(int i=x+1;i<=n;i++)l=max(l,(ld)a[i]-3*a[i]*a[i]),r=min(r,(ld)a[i]); while(r-l>1){ ld mid=(l+r)/2,tot=0; for(int i=x+1;i<=n;i++)tot+=sqrtl((a[i]-mid)/3); if(tot>ik)l=mid,ok=1;else r=mid; } for(int i=x+1;i<=n;i++)b[i]=sqrtl((a[i]-l)/3); return !ok; } int main(){ scanf("%d%lld",&n,&k); for(int i=1;i<=n;i++)scanf("%lld",a+i),c[i]={a[i],i}; sort(c+1,c+1+n,[](st a,st b){return a.x<b.x;}); for(int i=1;i<=n;i++)pos[c[i].y]=i; sort(a+1,a+1+n); int l=0,r=n+1,mid; while(l<r){ mid=l+r>>1; if(chk(mid))l=mid+1; else r=mid; } for(int i=1;i<=l;i++)b[i]=a[i]; chk(l); for(int i=1;i<=n;i++){k-=b[i];if(b[i]<a[i])q[++_]={a[i]-3*b[i]*b[i]-3*b[i],i};} sort(q+1,q+1+_,[](st a,st b){return a.x>b.x;}); for(int i=1;i<=k;i++)b[q[i].y]++; for(int i=1;i<=n;i++)printf("%lld%c",b[pos[i]]," \n"[i==n]); return 0; } ```