题解 P5633 【最小度限制生成树】

· · 题解

分享一个不用带权二分的做法吧。

为了方便描述,给所有非 s 的点一个点权,表示该点与 s 之间的所有边(虽然不知道数据造没造,但可能会有重边)的边权最小值(如果没边则点权设为无穷大)。则选中这个点相当于选它连向 s 的边。

题目转化成选中恰好 k 个点 + 删去若干条边,使得最终形成一个森林,且每棵树恰好选中一个点,最小化 “选中的点权值减去删边代价”。

接下来的讨论中默认不考虑 s

可以先对所有点求最小生成森林,则不在最小生成森林里的边一定会被删除(比较显然)。这样得到一个初始的森林。

对于初始的每棵树,首先必须选中一个点,可以证明贪心地选中点权最小的点最优。

f_x(i) 表示在连通分量 x 中选中 i 个点的最小答案,可以考虑贪心地使用增量法,从 f_x(i-1) 对应的答案加入一个新点使得 f_x(i) 最小。证明方法与上面那条贪心性质类似。

这里,设新点为 x,与 x 同属同一连通块且已经被选中的点为 y。则 x 的贡献等于 x 的点权减去 x\to y 这条路径上的最大边权。当然,加入新点 x 后要删去这条最大边。

注意到 f_x(i) 是凸的(也即该题带权二分的凸性质),那么我们可以求出所有 f_x(i)-f_x(i-1) 并从小到大排序,取出前若干个即可得到答案。

f_x(i)-f_x(i-1) 又可以看成 f_x(i-1)\to f_x(i) 中新加进来的点的贡献,于是只需要尝试求出每个点的贡献,也即每个点对应需要删掉的边即可。

考虑到具体实现,不妨去考虑每条边会被哪个点删掉。

从大到小考虑每条边,设最大边所连接的两个连通块中分别的最小值为 x,y,则最大边被删掉的时刻对应的点应该是 \max\{x, y\}(也比较显然)。

之后删掉最大边,再考虑次大边,以此类推。

发现这个过程反过来和 kruskal 基本一样,因此在 kruskal 的时候顺便维护一下即可。

无解的情况:至少需要的边太多,或者最多能连的边太少(注意点权为无穷大的点不能被选中)。

时间复杂度 O(M\log M),瓶颈在于 kruskal 的排序。

#include <bits/stdc++.h>

typedef long long ll;

const int M = 500000;
const int INF = (1 << 30);

struct edge{
    int u, v, w;
    friend bool operator < (const edge &a, const edge &b) {
        return a.w < b.w;
    }
}e[M + 5]; int cnt;

int fa[M + 5], val[M + 5], key[M + 5];
int find(int x) {return (fa[x] == x ? x : fa[x] = find(fa[x]));}
bool unite(int x, int y, int k) {
    int fx = find(x), fy = find(y);
    if( fx != fy ) {
        if( val[fx] < val[fy] ) std::swap(fx, fy);
        key[fx] = k, fa[fx] = fy; return true;
    } else return false;
}

int tmp[M + 5];
int main() {
    int n, m, s, k; scanf("%d%d%d%d", &n, &m, &s, &k);
    for(int i=1;i<=n;i++) if( i != s ) fa[i] = i, val[i] = INF;
    for(int i=1,u,v,w;i<=m;i++) {
        scanf("%d%d%d", &u, &v, &w);
        if( u == s ) val[v] = std::min(val[v], w);
        else if( v == s ) val[u] = std::min(val[u], w);
        else e[++cnt] = (edge){u, v, w};
    }

    ll ans = 0; std::sort(e + 1, e + cnt + 1);
    for(int i=1;i<=cnt;i++) if( unite(e[i].u, e[i].v, e[i].w) ) ans += e[i].w;

    int p = 0;
    for(int i=1;i<=n;i++) if( i != s && find(i) == i ) {
        if( val[i] == INF ) {puts("Impossible"); return 0;}
        p++, ans += val[i], val[i] = INF;
    }
    if( p > k ) {puts("Impossible"); return 0;}

    int tot = 0;
    for(int i=1;i<=n;i++) if( i != s && val[i] != INF ) tmp[++tot] = val[i] - key[i];
    if( p + tot < k ) {puts("Impossible"); return 0;}

    std::sort(tmp + 1, tmp + tot + 1); for(int i=1;i<=k-p;i++) ans += tmp[i];
    printf("%lld\n", ans);
}