题解:P3366 【模板】最小生成树

· · 题解

本文将讲解 Kruskal 以及 Prim 算法。

时间复杂度 O(n \log m)

Kruskal 的实现方式是并查集,思想是贪心。我们考虑优先对边权小的边加入最小生成树,然后最小生成树使用并查集维护。特别的,如果两点都已经在集合内,便跳过它,不连边。

众所周知点数为 n 的树,它的边数一定是 n-1,具体可以手玩一下。那么想要让最小生成树最小,肯定要构成一颗树,那么边数就是 n-1 了,即我们要在 m 条边中优先选择边权小的 n-1 条边(已经在最小生成树中的除外,并且顺延)。

证明:如果一个连通图属于最小生成树,那么从外部连接到连通图的最短边必然属于最小生成树,所以拆成若干个连通分量后,连通分量之间的最短边也必然属于最小生成树。

代码如下:

// Kruskal
// O(mlogm)
// 325ms
#include <bits/stdc++.h>
using namespace std;
struct edge {
    int u, v, w;
} e[200005];
int fa[5005], n, m, ans, cnt;
bool cmp(edge a, edge b) {
    return a.w < b.w;
}
int find(int x) {
    if(x != fa[x]) return fa[x] = find(fa[x]);
    else return x;
}
void kruskal() {
    sort(e + 1, e + m + 1, cmp);
    for(int i = 1; i <= m; i++) {
        int u = find(e[i].u);
        int v = find(e[i].v);
        if(u == v) continue;
        ans += e[i].w;
        fa[v] = u;
        cnt++;
        if(cnt > n - 1) break;
    }
}
int main() {
    cin >> n >> m;
    for(int i = 1; i <= n; i++) fa[i] = i;
    for(int i = 1; i <= m; i++) cin >> e[i].u >> e[i].v >> e[i].w;
    kruskal();
    if(cnt == n - 1) cout << ans;
    else puts("orz");
    return 0;
}

Prim 算法类似 Dijkstra 算法。Prim 算法的思想是以任意节点为根,找出所有与它相邻的所有边。再将新节点更新并以此节点作为根继续跑一遍。

证明:有任意点 v,那么对于 v 的最短相邻边,必然属于最小生成树。

代码如下:

// Prim
// O(mlogm)
// 327ms
#include <bits/stdc++.h>
#define MAXN 400500
using namespace std;
struct edge {
    int to, nxt, val;
} e[MAXN];
int h[MAXN], cnt, dis[MAXN], tot, ans;
bool vis[MAXN];
void add(int from, int to, int value) {
    e[++cnt].nxt = h[from];
    e[cnt].to = to;
    e[cnt].val = value;
    h[from] = cnt;
}
int n, m;
struct node{
    int pos, dis;
    friend bool operator < (node a, node b) {
        return a.dis > b.dis;
    }
} tmp;
priority_queue<node> q;
void prim() {
    for(int i = 1; i <= n; i++)
        dis[i] = 2147483647;
    dis[1] = 0;
    tmp.dis = 0;
    tmp.pos = 1;
    q.push(tmp);
    while(!q.empty()) {
        tmp = q.top();
        q.pop();
        int u = tmp.pos;
        int d = tmp.dis;
        if(vis[u]) continue;
        tot++;
        vis[u] = 1;
        ans += dis[u];
        for(int i = h[u]; i; i = e[i].nxt) {
            int v = e[i].to;
            int w = e[i].val;
            if(dis[v] > w) {
                tmp.dis = dis[v] = w;
                tmp.pos = v;
                q.push(tmp);
            }
        }
    }
}
int main() {
    cin >> n >> m;
    for(int i = 1; i <= m; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }
    prim();
    if(tot == n) cout << ans;
    else puts("orz");
    return 0;
}