最小差值生成树 题解

· · 题解

怎么题解里全是 LCT 而没有整体二分。

来写一发整体二分的题解。

对所有边先按边权排个序,我们记 f_i 为第 i 条边往后找能形成生成树的最小的位置,显然答案就是第 f_i 的边权和第 i 的边权差。

不难发现 f_i 单调不降,遂整体二分。

假设我们现在处理第 l 条边到第 r 条边,答案区间为 st,那么取中点 mid = \frac {l + r}{2} 去求 f _ {mid},假设我们已经求出,那么就可以对它两边分别去求解。

现在我们考虑怎么求出这个值,其实直接暴力就可以了,从 mid 开始往右找,暴力去判断是否合法,因为是整体二分还要用一个可撤销并查集。

大体思路就是这样的,我的代码里还用了一些优化的小技巧,应该不难理解。

code

#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 10;

vector <int> raw[N];

struct edge {
    int u, v, w;
} e[N];

int n, m, len, ans = 1e9;

int a[N], f[N];

int top, cnt;

int fa[N], sz[N], s[N];

int find (int x) {
    return fa[x] == x ? x : find (fa[x]);
}

inline void merge (int x, int y) {
    x = find (x), y = find (y);
    if (x == y) return;
    if (sz[x] > sz[y]) swap (x, y);
    return fa[x] = y, s[++top] = x, sz[y] += sz[x], cnt++, void ();
}

inline void slipt (int time) {
    while (top > time) {
        int x = s[top--], y = fa[x];
        fa[x] = x, sz[y] -= sz[x], cnt--;
    }
}

#define u e[j].u
#define v e[j].v

void solve (int l, int r, int s, int t) {
    if (l > r or s > t) return;

    if (s == t) {
        for (int i = l; i <= r; i++)    
            f[i] = s;
        return;
    }

    int mid = (l + r) >> 1, ver = top;
    f[mid] = len + 1;

    for (int i = mid; i <= min (r, s - 1); i++)
        for (int j : raw[i]) merge (u, v);

    for (int i = max (s, mid); i <= t; i++) {
        for (int j : raw[i]) {
            merge (u, v);
            if (cnt == n - 1) {
                f[mid] = i;
                break;
            }
        }

        if (cnt == n - 1) break; 
    }

    slipt (ver);

    for (int i = mid; i <= min (r, s - 1); i++)
        for (int j : raw[i]) merge (u, v);
    solve (l, mid - 1, s, f[mid] - 1), slipt (ver);

    for (int i = max (r + 1, s); i < f[mid]; i++)
        for (int j : raw[i]) merge (u, v);
    solve (mid + 1, r, f[mid], t), slipt (ver);
}

#undef u
#undef v

int main () {
    ios :: sync_with_stdio (false), cin.tie (0), cout.tie (0);

    cin >> n >> m;
    for (int i = 1; i <= m; i++) cin >> e[i].u >> e[i].v >> e[i].w, a[i] = e[i].w;

    sort (a + 1, a + 1 + m), len = unique (a + 1, a + 1 + m) - a - 1;
    for (int i = 1; i <= m; i++) e[i].w = lower_bound (a + 1, a + 1 + len, e[i].w) - a, raw[e[i].w].emplace_back (i);

    for (int i = 1; i <= n; i++) fa[i] = i, sz[i] = 1;

    solve (1, len, 1, len);

    for (int i = 1; i <= len; i++) 
        if (f[i] and f[i] != len + 1) ans = min (ans, a[f[i]] - a[i]);
    cout << ans;

    return 0;
}