CF626E Simple Skewness 题解

· · 题解

题意

给定 n 个非负整数,选一个子集,使得子集的平均值减去中位数最大。

求这个最大值。

\rm{Subtask} 1-\mathcal{O}(n^2)

因为需要考虑中位数,容易想到按照子集大小的奇偶性分类讨论。

子集大小为奇数

先来考虑奇数的情况。

容易想到 \mathcal{O}(n) 枚举中位数。

如上图,对数列按照升序排序。选定中位数后,我们每次从中位数左侧添加一个点,再从中位数右侧添加一个点。又因为我们想要子集的平均值最大,所以贪心地,每次分别从两侧取一个最大的数,也就是取最靠右的点,直到有一个区间被取完。

比如上图中,(4, 10)(3, 9)(2, 8) 两两为一组,每次分别被取出。

于是对于每个中位数,只需要 \mathcal{O}(n) 枚举两侧添加点的个数,就可以确定出选取的情况,求出平均值与中位数的差。

由于方案已经确定,直接输出即可。

这样可以得到整体复杂度为 \mathcal{O}(n^2) 的算法,期望得分未知。

子集大小为偶数

结论

对于大小为偶数的子集,观察可以得到结论:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优,下面进行证明。

分析

按照上面的贪心策略任取一个大小为奇数的子集,将其升序排序,设为 \{a_1, a_2, \cdots a_{2n-1}\}.

现在向子集中添加一个数,使其大小变为偶数。根据上面提到的贪心策略,新加入的数 x 有以下两种情况:

  1. 紧邻中位数右侧,即 a_n\leq x\leq a_{n+1}.

在上面的图里,可以形象地看到,x 的第一种情况是选择 1,第二种情况是在 6, 7 中选一个。

第一种情况,显然不会使答案变大,因为 x 对平均值的影响太大了。

对于第二种情况,下面进行详细证明。

证明

为了方便,我们把加入 x 后的数列重新升序排序,不妨设为 \{b_1, b_2, \cdots b_{2n}\}.

加入 x 前的答案为:

A=\frac{\sum_{i=1}^{2n}b_i-b_{n+1}}{2n-1}-b_{n}

加入后的答案为:

B=\frac{\sum_{i=1}^{2n}b_i}{2n}-\frac{b_{n} + b_{n + 1}}{2}

两者作差并化简后得到:

A-B=\frac{\sum_{i=1}^{2n}b_i}{2n(2n-1)}+\frac{b_{n+1}-b_n}{2}-\frac{b_{n+1}}{2n-1}

又因为 B\geq 0,所以:

\frac{\sum_{i=1}^{2n}b_i}{2n}\geq \frac{b_n+b_{n+1}}{2}

即:

\frac{\sum_{i=1}^{2n}b_i}{2n(2n-1)}\geq \frac{b_n+b_{n+1}}{2(2n-1)}

代入 A-B 中得到:

\begin{aligned} A-B&\geq \frac{b_n+b_{n+1}}{2(2n-1)}+\frac{b_{n+1}-b_n}{2}-\frac{b_{n+1}}{2n-1}\\ &=\frac{b_n+b_{n+1}+(2n-1)(b_{n+1}-b_n)-2b_{n+1}}{2(2n-1)}\\ &=\frac{b_n+b_{n+1}+2n\cdot b_{n+1}-2n\cdot b_n-b_{n+1}+b_n-2b_{n+1}}{2(2n-1)}\\ &=\frac{(n-1)(b_{n+1}-b_n)}{(2n-1)}\geq 0 \end{aligned}

于是我们证明了:加入 x 前的答案一定不小于加入 x 后的答案,也就是:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优

\mathcal{O}(n^2) 代码

以下代码输出的是:平均值与中位数的差的最大值

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3f;
inline int read() {}//快读
int n, a[MAXN], sum[MAXN];
signed main() {
    n = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    sort(a + 1, a + n + 1);
    for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
    double ans = (-1.0) * inf;
    for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右
        double mmax = (double)(a[i]);//只取中位数
        for(int j = 1; j <= min(i - 1, n - i); j++) {//subtask1 暴力枚举两侧加2*j个点
            int add_l = sum[i - 1] - sum[i - j - 1];
            int add_r = sum[n] - sum[n - j];
            mmax = max(mmax, (double)(add_l + add_r + a[i]) / (double)(2 * j + 1));
        }
        ans = max(ans, mmax - (double)(a[i]));
    }
    printf("%.2lf\n", ans);
    return 0;
}

\rm{Subtask} 2-\mathcal{O}(n\log n)

考虑如何把时间复杂度优化到 \mathcal{O}(n\log n) 级别。

在此基础上,设第 $i$ 次添加的两个元素**下标**分别为 $x_i$ 和 $y_i$. 当确定中位数时,答案只与平均值有关,考虑计算添加 $(x_i, y_i)$ 对平均值的贡献。 贡献,也就是添加这两个数之后,**平均值的变化量**,若平均值变大,则贡献为正,反之为负。 显然地,**随着 $i$ 的增长**,$a_{x_i}+a_{y_i}$ 会持续下降,**对平均值的贡献也会越来越少**。 我们最后一次选取的两个元素 $x_{len}, y_{len}$,必定满足贡献大于 $0$,并且是最接近 $0$ 的。 因此我们可以通过**二分**,找到贡献紧邻最接近0的数。 这样就确定了选择哪些数,前缀和计算答案即可。 具体实现时,我们不需要准确地计算出每两个元素的贡献。若**加入两个元素后的平均值大于原数列的平均值,则贡献为正**,否则贡献为负。这部分需要利用前缀和求出数列平均值。 对于输出方案,已经确定,直接输出即可。 这样就得到了 $\mathcal{O}(n\log n)$ 的算法,期望得分 $\rm{100pts}$. ## $\rm{100pts}$ 代码 ```cpp #include <bits/stdc++.h> #define int long long using namespace std; const int MAXN = 2e5 + 10; const int inf = 0x3f3f3f3f3f3f3f3f; inline int read() { int ret = 0; char ch = getchar(); while(ch < '0' || ch > '9') ch = getchar(); while(ch <= '9' && ch >= '0') { ret = ret * 10 + ch - '0'; ch = getchar(); } return ret; } int n; int a[MAXN], sum[MAXN]; signed main() { n = read(); for(int i = 1; i <= n; i++) a[i] = read(); sort(a + 1, a + n + 1); for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i]; double ans = (-1.0) * inf; int ans_cnt = 0, ans_mid = 0; for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右 int l = 1, r = min(i - 1, n - i), mid; double mmax = (double)(a[i]); int cnt = 0; while(l <= r) { mid = (l + r) >> 1; int sum1 = sum[i - 1] - sum[i - mid] + sum[n] - sum[n - mid + 1]; //j = mid - 1,即加入前的平均值 int sum2 = sum[i - 1] - sum[i - mid - 1] + sum[n] - sum[n - mid]; //j = mid,加入两个元素后的平均值 if((double)(sum1 + a[i]) / (double)(2 * mid - 1) < (double)(sum2 + a[i]) / (double)(2 * mid + 1)) { if((double)(sum2 + a[i]) / (double)(2 * mid + 1) > mmax) { mmax = (double)(sum2 + a[i]) / (double)(2 * mid + 1); cnt = mid; } l = mid + 1; } else { r = mid - 1; } } if(mmax - double(a[i]) > ans) { ans = mmax - (double)(a[i]); ans_cnt = cnt; ans_mid = i; } } cout << ans_cnt * 2 + 1 << endl; for(int i = ans_mid - ans_cnt; i <= ans_mid; i++) cout << a[i] << " "; for(int i = n - ans_cnt + 1; i <= n; i++) cout << a[i] << " "; return 0; }```