P7390 「EZEC-6」造树
更好的阅读体验
考虑贪心,按权值排序,对于枚举到的节点
我们用一个指针
我写题解一般都写的很少的,但这次是验题的工作,所以写一下详细证明。
若当前枚举到的
若已知
所以对于上述过程,因为在任意时刻,全局最小的权值会和当前该权值能连边的最小权值相乘,所以最终结果一定是最大值。
可以反证,首先初始的边集为空集,肯定为最优解的子集,如果最小的权值
这个过程具有对称性,所以对权值的排序从小到大和从大到小是等价的。
而我们保证了
时间复杂度O(n),这里为了复杂度对应用的桶排,实际上用sort可以在2s内通过所有数据。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <queue>
#define mp(x, y) make_pair(x, y)
#define file(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout);
using namespace std;
const int N = 1e7 + 10;
inline int read() {
bool sym = 0; int res = 0; char ch = getchar();
while (!isdigit(ch)) sym |= (ch == '-'), ch = getchar();
while (isdigit(ch)) res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
return sym ? -res : res;
}
struct NODE {int d, val;} dat[N], t[N];
int n, m = 5e5, last, cnt[N];
long long ans;
namespace STD {
unsigned seed;
unsigned rnd(unsigned x) {
x ^= x << 13; x ^= x >> 17; x ^= x << 5; return x;
}
int rad(int x, int y) {
seed = rnd(seed); return seed % (y - x + 1) + x;
}
void init() {
seed = read();
for(int i = 1; i <= n; i++) t[i].d = 1, t[i].val = rad(1, 500000);
for(int i = 1; i <= n - 2; i++) t[rad(1, n)].d++;
}
}
void sort() {
for (int i = 1; i <= n; i++) cnt[t[i].val]++;
for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for (int i = 1; i <= n; i++) dat[cnt[t[i].val]--] = t[i];
}
int main() {
int type = read(); n = read();
if (type == 1) STD::init(); else {
for (int i = 1; i <= n; i++) t[i].d = read();
for (int i = 1; i <= n; i++) t[i].val = read();
}
sort(); last = 1;
for (int i = 1, j = 2, k = 2; i <= n; i++, j = max(j, i + 1)) {
while (dat[i].d == 0 && i <= n) i++; if (i > n) break;
while (dat[i].d > 1) {
while (dat[j].d == 0) j++; if (dat[j].d > 1) last = j;
ans += 1ll * dat[i].val * dat[j].val; dat[i].d--; dat[j].d--; j++;
}
if (last == i) {
k = max(j, k); while (dat[k].d <= 1 && k < n) k++;
ans += 1ll * dat[i].val * dat[k].val; dat[i].d--; dat[k].d--; i = j;
while (i < k && dat[k].d > 1) {
ans += 1ll * dat[i].val * dat[k].val; dat[i].d--; dat[k].d--; i++;
}
if (dat[k].d == 1) last = i; else last = k; i--;
} else {
while (dat[j].d == 0 && j < n) j++; if (dat[j].d > 1) last = j;
ans += 1ll * dat[i].val * dat[j].val; dat[i].d--; dat[j].d--; j++;
}
}
printf("%lld", ans);
return 0;
}