题解:P6362 平面欧几里得最小生成树

· · 题解

注意到是完全图 MST,又不想学神秘科技,所以考虑 Boruvka 算法。

那么现在就需要做到对于一个在 c 集合里的点 u,找出不在 c 集合里距离 u 最近的点 v,也就是近邻。

这个东西可以考虑使用 KDT 解决。

注意,这里不能使用找出最小次小的方法来获取非 c 集合的点 v,这样复杂度是完全错的。

所以每次对于一个集合 c,直接删掉 c 集合中的所有点,然后对于每个点找近邻即可。

注意到 Boruvka 对于一个集合只需要找出一条连到别的集合最短的边,所以对于同一个集合,直接延用一个 ans,在 KDT 上遍历的时候看能不能使 ans 变得更优即可。

接下来是对于复杂度的一些讨论。

首先 Boruvka 遍历 O(\log n) 次,KDT 上删除添加是 O(\log n),所以打底 O(n \log^2 n)

然后是 KDT 找近邻的部分,但是有人说是 O(n^2) 的,也有人说是 O(n\sqrt{n}) 的。

但是这玩意很难卡,其中一个原因是因为题目中的值域很小且都是整数。还有一点就是 Boruvka 会使点合并,即使第一轮卡很满,后面依然会跑很快,所以不是很能卡掉,再加个随机旋转就基本卡不掉了。

由于邪恶的ケロシ加了神秘数据来卡 KDT 常数,所以要进行一些神秘卡常,然后直接 c++98 提交即可。

const int N = 1e5 + 5;
const int INF = 1e8 + 7;
const ll LNF = 1e18;
int n, cmp;
struct Point {
    double x[2];
    bool operator < (const Point A) const {
        return x[cmp] < A.x[cmp];
    }
} a[N];
struct KDT {
    double L[2], R[2];
    int ls, rs, p, b;
} t[N]; int rt, tot;
struct Node {
    double w; int c;
    bool operator > (Node & A) const {
        if(w == A.w) return c > A.c;
        return w > A.w;
    }
} p[N];
int f[N];
vector<int> e[N];
inline double sq(double x) {
    return x * x;
}
inline double dist(Point A, Point B) {
    return sq(A.x[0] - B.x[0]) + sq(A.x[1] - B.x[1]);
}
inline int find(int u) {
    if(f[u] == u) return u;
    return f[u] = find(f[u]);
}
inline void pushup(int u) {
    int l = t[u].ls, r = t[u].rs;
    REP(i, 2) {
        if(t[u].b) t[u].L[i] = t[u].R[i] = a[t[u].p].x[i];
        else t[u].L[i] = INF, t[u].R[i] = - INF;
        if(l) {
            chmin(t[u].L[i], t[l].L[i]);
            chmax(t[u].R[i], t[l].R[i]);
        }
        if(r) {
            chmin(t[u].L[i], t[r].L[i]);
            chmax(t[u].R[i], t[r].R[i]);
        }
    }
}
int build(int l, int r, int o) {
    if(l > r) return 0;
    int mid = l + r >> 1;
    int u = ++ tot;
    cmp = o;
    nth_element(a + l, a + mid, a + r + 1);
    t[u].b = 1;
    t[u].p = mid;
    t[u].ls = build(l, mid - 1, o ^ 1);
    t[u].rs = build(mid + 1, r, o ^ 1);
    pushup(u);
    return u;
}
void query(int u, int v) {
    if(! u) return;
    int c = find(v);
    if(t[u].b) chmin(p[c], (Node){dist(a[v], a[t[u].p]), find(t[u].p)});
    int o0 = t[u].L[0] <= a[v].x[0] && a[v].x[0] <= t[u].R[0];
    int o1 = t[u].L[1] <= a[v].x[1] && a[v].x[1] <= t[u].R[1];
    if(! o0 && ! o1) {
        double val = LNF;
        chmin(val, dist(a[v], {t[u].L[0], t[u].L[1]}));
        chmin(val, dist(a[v], {t[u].L[0], t[u].R[1]}));
        chmin(val, dist(a[v], {t[u].R[0], t[u].L[1]}));
        chmin(val, dist(a[v], {t[u].R[0], t[u].R[1]}));
        if(val >= p[c].w) return;
    }
    if(o0 && ! o1) {
        double val = LNF;
        chmin(val, sq(a[v].x[1] - t[u].L[1]));
        chmin(val, sq(a[v].x[1] - t[u].R[1]));
        if(val >= p[c].w) return;
    }
    if(! o0 && o1) {
        double val = LNF;
        chmin(val, sq(a[v].x[0] - t[u].L[0]));
        chmin(val, sq(a[v].x[0] - t[u].R[0]));
        if(val >= p[c].w) return;
    }
    query(t[u].ls, v);
    query(t[u].rs, v);
}
void insert(int u, int l, int r, int q, int x) {
    int mid = l + r >> 1;
    if(mid == q) {
        t[u].b = x;
        pushup(u);
        return;
    }
    if(q < mid) insert(t[u].ls, l, mid - 1, q, x);
    else insert(t[u].rs, mid + 1, r, q, x);
    pushup(u);
}
void solve() {
    cin >> n;
    double alpha = 1.14;
    FOR(i, 1, n) {
        double X, Y;
        cin >> X >> Y;
        a[i].x[0] = X * cos(alpha) - Y * sin(alpha);
        a[i].x[1] = X * sin(alpha) + Y * cos(alpha);
    }
    rt = build(1, n, 0);
    FOR(i, 1, n) f[i] = i;
    int cnt = n;
    double ans = 0;
    while(cnt > 1) {
        FOR(i, 1, n) p[i] = {LNF, 0};
        FOR(i, 1, n) e[i].clear();
        FOR(i, 1, n) e[find(i)].push_back(i);
        FOR(i, 1, n) if(! e[i].empty()) {
            FORV(x, e[i]) insert(rt, 1, n, * x, 0);
            FORV(x, e[i]) query(rt, * x);
            FORV(x, e[i]) insert(rt, 1, n, * x, 1);
        }
        FOR(i, 1, n) if(p[i].c) {
            int u = find(i);
            int v = find(p[i].c);
            if(u == v) continue;
            f[v] = u;
            cnt --;
            ans += sqrt(p[i].w);
        }
    }
    cout << fixed << setprecision(10);
    cout << ans << endl;
}