UVA1205 Color a tree

· · 题解

一道经典的贪心题目,2021 年再看颇有感触。

题目大意

有一棵 n 个结点的树,每个节点都有一个权值 a_i。现在要将这棵树的结点全部染色,规定如下:

  1. 根节点 r 可以随时被染色

  2. 对于其他节点,其在被染色之前其父亲节点必须已经染色。

k 次染色的代价为 k×a_ik 类似时间戳。求将整棵树染色的最小总代价。

大体思路

考虑贪心:显然,对于树中除了根节点外权值最大的结点,必然会在其父亲节点被染色后立刻被染色。因此,树中权值最大的结点以及其父节点的染色时连续进行的。受到 NOIWC 2021 T1 的启发,考虑将这样的结点合并起来。

假设有权值为 x,y,z 的三个节点,且 x,y 的染色是连续的,则存在两种方案:

  1. x, y, 再染 z,其代价为 W_1=x+2y+3z

  2. x, y, 先染 z,其代价为 W_2=2x+3y+z

\text{if}\ W_1-W_2=2z-(x+y)≥0\ \text{Then} \ \dfrac{x+y}2≤z

因此,两个权值为 x, y 的结点等价于两个权值为 \dfrac{x+y}2 的结点。

由此可以推广到任意多个结点:对于 ma_1,a_2,⋯,a_m 连续染色,nb_1,b_2,⋯,b_n 连续染色,两种代价分别为

W_1 =\sum_{i=1}^m i\times a_i+\sum_{i=1}^n (i+m)\times b_i,\ W_2 =\sum_{i=1}^n i\times b_i+\sum_{i=1}^m (i+n)\times a_i \text{if}\ W_1\le W_2,\ \text{Then}\ W_1-W_2=m\times\sum_{i=1}^n b_i-n\times\sum_{i=1}^m a_i\le 0 \therefore \dfrac{a_1+a_2+\cdots +a_m}m\ge \dfrac{b_1+b_2+\cdots +b_n}n

因此,我们可以记录一个点是由多少个点合并而成,然后得到其等效权值:即

\dfrac{\text{该点包含的原始权值之和}}{\text{其包含的原始结点数}}

具体地,我们利用大根堆,每次从大根堆中取出等效权值最大的结点 p,与其父亲节点合并,让 p 的第一个点紧接在其父亲节点之后染色,并保存顺序,直到树合并为一个点,再顺序遍历统计代价。复杂度 O(n\log n)

完整代码:

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <queue>
using namespace std;
const int N = 1e3 + 5;
int n, r, fa[N], nxt[N], lst[N], num[N];
double c[N], d[N], tot[N];
bool vis[N];
typedef pair<double, int> PII;
int main(){
    while(scanf("%d%d", &n, &r) && n && r) {//多组数据
        priority_queue <PII, vector<PII>, less<PII> > q;//大根堆
        for (int i = 1; i <= n; i ++ ){//O(n)
            scanf("%lf", &c[i]);
            nxt[i] = lst[i] = i;
            num[i] = 1;
            tot[i] = c[i];
            if(i != r) q.push({c[i], i});//将非根节点入队
        }
        memcpy(d, c, sizeof d);
        for (int i = 1; i < n; i ++ ) {//O(n)
            int a, b;
            scanf("%d%d", &a, &b);
            fa[b] = a;
        }
        memset(vis, 0, sizeof vis);
        for (int i = 1; i < n; i ++ ){
            int p; 
            p = q.top().second; q.pop();
            while(vis[p] || p == r) {
                p = q.top().second;
                q.pop();
        }//取出等效权值最大且合法的结点
            int f = fa[p];
            while(vis[f]) fa[p] = f = fa[f];
            nxt[lst[f]] = p;
            lst[f] = lst[p];
            num[f] += num[p], tot[f] += tot[p];
            c[f] = tot[f] / num[f];//合并
            vis[p] = 1;
            q.push({c[f], f});//将更新的结果入队
        }
        int ans = 0;
        for (int i = 1; i <= n; i ++ ){
            ans += i * d[r];
            r = nxt[r];
        }//计算代价
        printf("%d\n", ans);
    }

    return 0;
}