P4180 [BJWC2010] 严格次小生成树
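算法思路详见其他题解,本篇主要分享代码实现上的写法。简要回顾:先用 Kruskal 求出最小生成树(权值和为 sum);再对每条非树边 (u, v, w),用倍增求出树上 u 到 v 路径上边权的最大值 x 与严格次大值 y;若 w > x,候选答案为 sum + w - x;若 w = x 且 y 存在,候选答案为 sum + w - y;所有候选中的最小值即为严格次小生成树的权值。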
代码实现
#include <bits/stdc++.h>
using namespace std;
// Debug helper: dbg(a, b, c) prints "[a, b, c] = <a>, <b>, <c>" to stderr.
// The macro stringizes its arguments for the label, then forwards the values to debug_out.
#define dbg(...) std::cerr << "[" << #__VA_ARGS__ << "] = ", debug_out(__VA_ARGS__)
// Prints the first value, then ", <value>" for every remaining value, then ends the line.
// Arguments are taken by const reference so printing strings/containers never copies them.
template <typename T, typename... Ts> void debug_out(const T& t, const Ts&... ts) {
  std::cerr << t;
  ((std::cerr << ", " << ts), ...);  // C++17 fold: one ", value" per remaining argument
  std::cerr << std::endl;
}
// Fixed-point combinator: wraps a callable so it can recurse.
// Calling the wrapper invokes f with the wrapper itself as the first argument
// (conventionally named `self`), followed by the caller's arguments.
template <class F>
struct y_combinator {
  F f;
  template <class... Args>
  decltype(auto) operator()(Args&&... args) {
    return f(*this, std::forward<Args>(args)...);
  }
  // Const overload so a const wrapper (or one reached through a const reference) is callable.
  template <class... Args>
  decltype(auto) operator()(Args&&... args) const {
    return f(*this, std::forward<Args>(args)...);
  }
};
// Deduces F from the argument. The callable is moved (not copied) into the wrapper,
// which avoids a needless copy and also accepts lambdas with move-only captures.
template <class F>
auto make_y_combinator(F f) { return y_combinator<F>{std::move(f)}; }
int main() {
cin.tie(0)->sync_with_stdio(0);
int n, m;
cin >> n >> m;
vector<tuple<int, int, int, bool>> edge;
for (int i = 0; i < m; i++) {
int u, v, w;
cin >> u >> v >> w;
if (u == v) continue;
edge.emplace_back(w, u, v, 0);
}
sort(edge.begin(), edge.end());
vector<int> ufs(n + 1);
iota(ufs.begin(), ufs.end(), 0);
auto get = make_y_combinator([&](auto self, int x) -> int {
return ufs[x] == x ? x : ufs[x] = self(ufs[x]);
});
using ll = long long;
ll sum = 0;
for (auto &[w, u, v, t] : edge) {
int x = get(u), y = get(v);
if (x != y) ufs[x] = y, sum += w, t = 1;
}
// dbg(sum);
vector<vector<tuple<int, int>>> e(n + 1);
for (auto &[w, u, v, t] : edge)
if (t) e[u].emplace_back(v, w), e[v].emplace_back(u, w);
vector<int> d(n + 1);
int w = 31 - __builtin_clz(n - 1);
vector<vector<int>> f(w, vector<int>(n + 1));
vector<vector<tuple<int, int>>> g(w, vector<tuple<int, int>>(n + 1));
auto dfs = make_y_combinator([&](auto self, int u) -> void {
for (auto &[v, w] : e[u]) {
if (v == f[0][u]) continue;
d[v] = d[u] + 1;
f[0][v] = u;
g[0][v] = {w, -1};
self(v);
}
});
dfs(1);
auto merge = [&](tuple<int, int> a, tuple<int, int> b) {
auto [x1, y1] = a;
auto [x2, y2] = b;
if (x1 > x2) return make_tuple(x1, max(y1, x2));
if (x1 < x2) return make_tuple(x2, max(y2, x1));
return make_tuple(x1, max(y1, y2));
};
for (int i = 1; i < w; i++)
for (int u = 1; u <= n; u++) {
f[i][u] = f[i - 1][f[i - 1][u]];
g[i][u] = merge(g[i - 1][u], g[i - 1][f[i - 1][u]]);
}
auto ask = [&](int u, int v) {
tuple<int, int> res = {-1, -1};
if (d[u] < d[v]) swap(u, v);
for (int i = w - 1; i >= 0; i--)
if (d[f[i][u]] >= d[v]) {
res = merge(res, g[i][u]);
u = f[i][u];
}
if (u == v) return res;
for (int i = w - 1; i >= 0; i--)
if (f[i][u] != f[i][v]) {
res = merge(res, g[i][u]);
res = merge(res, g[i][v]);
u = f[i][u], v = f[i][v];
}
res = merge(res, g[0][u]);
res = merge(res, g[0][v]);
return res;
};
ll ans = 1e18;
for (auto &[w, u, v, t] : edge) {
if (t) continue;
auto [x, y] = ask(u, v);
// dbg(w, u, v, x, y);
if (w > x)
ans = min(ans, sum + w - x);
else if (w == x && y != -1)
ans = min(ans, sum + w - y);
}
cout << ans << endl;
return 0;
}
细节说明
支持可变参数的调试宏 dbg
// Variadic debug macro: prints "[<argument text>] = " followed by the comma-separated values.
#define dbg(...) cerr << "[" << #__VA_ARGS__ << "] = ", debug_out(__VA_ARGS__)
// Base case: print the last value and end the line.
template <typename T> void debug_out(T t) { cerr << t << endl; }
// Recursive case: print one value and a separator, then recurse on the remaining values.
template <typename T, typename... Ts> void debug_out(T t, Ts... ts) {
cerr << t << ", ";
debug_out(ts...);
}
Y Combinator 实现递归 Lambda
// Fixed-point combinator wrapper: calling it invokes f with the wrapper itself as the
// first argument, so the wrapped lambda can recurse through that parameter.
template <class F>
struct y_combinator {
F f;
template <class... Args>
decltype(auto) operator()(Args&&... args) {
return f(*this, std::forward<Args>(args)...);
}
};
// Helper that deduces F, so callers can write make_y_combinator(lambda).
template <class F>
auto make_y_combinator(F f) { return y_combinator<F>{f}; }
这是一个不动点组合子的实现,允许我们写出递归的 lambda 表达式:
auto get = make_y_combinator([&](auto self, int x) -> int {
return ufs[x] == x ? x : ufs[x] = self(ufs[x]);
});
调用时类似普通函数:
int x = get(u), y = get(v);
结构化绑定
利用 C++17 结构化绑定按引用(auto &)解包 tuple,循环中对 t 的赋值会直接写回 edge,从而标记出树边:
for (auto &[w, u, v, t] : edge) {
int x = get(u), y = get(v);
if (x != y) ufs[x] = y, sum += w, t = 1;
}
最大值和次大值的合并
使用 tuple<int, int> 存储路径上的最大值和次大值:
auto merge = [&](tuple<int, int> a, tuple<int, int> b) {
auto [x1, y1] = a;
auto [x2, y2] = b;
if (x1 > x2) return make_tuple(x1, max(y1, x2));
if (x1 < x2) return make_tuple(x2, max(y2, x1));
return make_tuple(x1, max(y1, y2));
};
这个 merge 函数合并两段路径各自的(最大值, 严格次大值):两者最大值不同时,较小的那个最大值也成为次大值的候选;两者最大值相同时只比较两个次大值。次大值用 -1 表示不存在,因此这里依赖边权非负。
并查集的初始化方式
vector<int> ufs(n + 1);
iota(ufs.begin(), ufs.end(), 0);
使用 iota 函数初始化并查集数组,比循环更简洁。
内置函数的使用
int w = 31 - __builtin_clz(n - 1);
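使用 GCC 内置函数 __builtin_clz 计算前导零个数,从而得到倍增数组大小。注意:倍增层数 w 需满足 2^w > n - 1(树的最大深度),而 31 - __builtin_clz(n - 1) 等于 ⌊log2(n - 1)⌋,只保证 2^w ≤ n - 1,树退化成链时会少一层,导致跳不到目标深度;另外 n = 1 时 __builtin_clz(0) 是未定义行为。更稳妥的写法是 int w = 32 - __builtin_clz(n);(即 n 的二进制位数)。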
倍增数组维度顺序
把倍增层数放在第一维(f[i][u] 而不是 f[u][i]),按层递推时(外层枚举 i、内层枚举 u)每次只读取 f[i - 1] 这一行,内存访问更集中:
vector<vector<int>> f(w, vector<int>(n + 1));
vector<vector<tuple<int, int>>> g(w, vector<tuple<int, int>>(n + 1));