【题解】CF1355E Restorer Distance

· · 题解

前言

本文的主要目的是证明以最终高度为自变量,最小代价为因变量的函数是下凸函数

其他题解的证明没有考虑到可能产生的 X \le YX>Y 的交界XY 是下文证明中的未知数)。在此基础上,大多数题解的证明比较含糊;不含糊的@马保国官方号的题解的证明是错的(它计算 C' 的第三行漏掉了 AN 这一项,得出了 C'<0 这个错误结论)。

同时也希望各位读者:如果发现了本文中出现的任何错误,请立即在评论区指出或私信笔者,笔者将尽快回复和修改。

题解

首先,一次操作三相当于一次操作一和一次操作二,不妨令 M=\min(M,A+R)

那么,执行 X 次操作一和 Y 次操作二的最小代价为 (X-\min(X,Y))A+(Y-\min(X,Y))R+\min(X,Y) \cdot M
(即:将操作一和操作二尽可能配对,每对花费 M 的代价)

然后,设最终所有砖块的高度都变为了 H,那么上述的 X,Y 应满足 X=\sum_{h_i \le H}(H-h_i)Y=\sum_{h_i>H}(h_i-H)
(我们显然不会对同一砖块的高度又加又减)

f(H) 表示最终所有砖块的高度都变为了 H 的最小代价,则有:

\begin{cases} (Y-X)R+XM & (X \le Y) \\ (X-Y)A+YM & (X > Y) \end{cases}

(就是将第二段那个式子中的 \min(X,Y)XY 代替)

为了方便,下文用 f(H) 相关的变量加上一撇表示 f(H+1) 相对应的变量

设初始高度 \le H 的砖块恰有 P 个,则 X'=X+PY'=Y-(N-P)

说一下 X',Y' 的计算方法:先将 X,Y 式子中的 H 变为 H+1,然后将 h_i=H+1 的部分从 Y 移到 X,就变成了 X',Y'。(其中 h_i=H+1 的部分在 XY 中都为 0,无影响)

令函数 f'(H)=f(H+1)-f(H),分类讨论 f'(H) 的值:

X \le YX' \le Y',则

f(H+1)-f(H)=(Y'-X')R+X'M-(Y-X)R-XM=M \cdot P-NR \quad (1)

X \le YX'>Y',则

f(H+1)-f(H)=(X'-Y')A+Y'M-(Y-X)R-XM=M \cdot P+(X-Y+N)A+(Y-X-N)M+(X-Y)R \quad (2)

X>YX'>Y',则

f(H+1)-f(H)=(X'-Y')A+Y'M-(X-Y)A-YM=M \cdot P+NA-NM \quad (3)

请注意:随着 H 的增大,X 增大,Y 减小,P 增大,(1)(2)(3) 三种情况依次出现。

由于 (1)(2)(3) 中P 以外都是定值,所以 f' 函数在 (1)(2)(3) 这三段分别单调不降

(2)-(1) 得 M \Delta P+(X-Y+N)(A+R-M)。由于 \Delta P \ge 0X-Y+N=X'+Y'+2(N-P)>0M \le A+R,所以 (2)-(1) \ge 0,所以 (1)(2) 合起来也单调不降

(3)-(2) 得 M \Delta P-(X-Y)(A+R-M)。由于 \Delta P \ge 0X \le YM \le A+R,所以 (3)-(2) \ge 0,所以 (2)(3) 合起来也单调不降

综上所述,f' 在定义域内单调不降—— f 是下凸函数

P.S.

学过导数的人都知道 f' 就是 f 的导函数(注意 f 自变量为整数),而下凸函数的一种等价定义是其导函数单调不降,或表述为其二阶导函数恒 \ge 0

如果你没学过导数,可以将其简单地理解为函数 f 变化量呈上升趋势,画个图就可以知道这可以判定下凸函数。

代码

P.S. 按本代码中整数三分的写法,边界条件应为 $[l,r]$ 的长度 $<3$,**若写法不同应具体情况具体处理**。 ``` cpp #include<bits/stdc++.h> using namespace std; int N,A,R,M; const int max_N=1e5+5; int h[max_N]; long long sum[max_N]; inline long long calc(int H) { int p=upper_bound(h+1,h+N+1,H)-h-1; long long X=1ll*p*H-sum[p],Y=(sum[N]-sum[p])-1ll*(N-p)*H,c=min(X,Y); return A*(X-c)+R*(Y-c)+M*c; } int main() { scanf("%d%d%d%d",&N,&A,&R,&M); M=min(M,A+R); for(int i=1;i<=N;++i) scanf("%d",h+i); sort(h+1,h+N+1); for(int i=1;i<=N;++i) sum[i]=sum[i-1]+h[i]; int l=h[1],r=h[N]; while(l<r) { int lmid=l+(r-l)/3,rmid=r-(r-l)/3; if(calc(lmid)<calc(rmid)) r=rmid-1; else l=lmid+1; } printf("%lld\n",calc(l)); return 0; } ```