P8336 [Ynoi2004] 2stmst

· · 题解

设 F(x, y) = D(x, y) + D(y, x),那么当 x 是 y 的祖先时,F(x, y) = sz_x - sz_y;当 y 是 x 的祖先时,F(x, y) = sz_y - sz_x;当 x, y 不互为祖孙关系时,F(x, y) = sz_x + sz_y。图上 i, j 之间的边权即为 F(x_i, x_j) + F(y_i, y_j)。

完全图 MST 容易想到 Boruvka,问题转化为求一端为 i 且另一端与 i 不在同一个连通块的边权最小值。然后是 Boruvka 的经典套路,考虑直接求出一个信息:一端为 i 的边权最小值和次小值,钦定这两条边的另一个端点不在同一个连通块。这个信息是可以合并的,所以可以当成没有不在同一个连通块的限制然后做。

分 9 种情况讨论:x_i 与 x_j、y_i 与 y_j 各有“前者是祖先 / 前者是后代 / 不互为祖孙”3 种关系,共 3 × 3 = 9 种组合,每种组合分别维护最小值和次小值。

时间复杂度 O(n \log m + m \log n \log m)。实现时可以预处理出 dfs 序,就不用每次都 dfs 了。

代码看起来很长,但是很多内容都是重复的。

// Problem: P8336 [Ynoi2004] 2stmst
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P8336
// Memory Limit: 512 MB
// Time Limit: 6000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

// Buffered stdin/stdout helpers.
// Input is pulled from stdin in chunks of maxn bytes into ibuf; output is
// accumulated in obuf and written to stdout by flush(), which also runs
// automatically when the AutoFlush object is destroyed.
namespace IO {
    const int maxn = 1 << 20;

    // iS/iT delimit the unread part of ibuf; oS is the write cursor in obuf.
    char ibuf[maxn], *iS, *iT, obuf[maxn], *oS = obuf;

    // Returns the next input byte, refilling ibuf from stdin when it is
    // exhausted. Returns EOF (converted to char) when the refill reads nothing.
    inline char gc() {
        if (iS == iT) {
            iS = ibuf;
            iT = ibuf + fread(ibuf, 1, maxn, stdin);
            if (iS == iT) {
                return EOF;
            }
        }
        return *iS++;
    }

    // Parses the next decimal integer from stdin, skipping leading non-digit
    // bytes; any '-' among the skipped bytes negates the result.
    // Returns 0 if the input ends before a digit is found.
    template<typename T = int>
    inline T read() {
        char c = gc();
        T x = 0;
        bool f = 0;
        while (c < '0' || c > '9') {
            // After a refill that read zero bytes gc() leaves iT == ibuf; a
            // successful refill always leaves iT > ibuf. Without this check
            // the loop would spin forever on exhausted input.
            if (iT == ibuf) {
                return x;
            }
            f |= (c == '-');
            c = gc();
        }
        while (c >= '0' && c <= '9') {
            x = (x << 1) + (x << 3) + (c ^ 48);
            c = gc();
        }
        return f ? ~(x - 1) : x;
    }

    // Writes the buffered output to stdout and empties the buffer.
    inline void flush() {
        fwrite(obuf, 1, oS - obuf, stdout);
        oS = obuf;
    }

    // Flushes pending output when this object is destroyed.
    struct Flusher {
        ~Flusher() {
            flush();
        }
    } AutoFlush;

    // Appends one byte to the output buffer, flushing first if it is full.
    inline void pc(char ch) {
        if (oS == obuf + maxn) {
            flush();
        }
        *oS++ = ch;
    }

    // Appends the decimal representation of x (with a leading '-' if negative).
    template<typename T>
    inline void write(T x) {
        // Digits are produced least-significant first, then emitted in reverse.
        static char stk[64], *tp = stk;
        if (x < 0) {
            x = ~(x - 1);
            pc('-');
        }
        do {
            *tp++ = x % 10;
            x /= 10;
        } while (x);
        while (tp != stk) {
            pc((*--tp) | 48);
        }
    }

    // write(x) followed by a single space.
    template<typename T>
    inline void writesp(T x) {
        write(x);
        pc(' ');
    }

    // write(x) followed by a newline.
    template<typename T>
    inline void writeln(T x) {
        write(x);
        pc('\n');
    }
}

using IO::read;
using IO::write;
using IO::pc;
using IO::writesp;
using IO::writeln;

// Capacity of the statically sized arrays in this file.
const int maxn = 1000100;
// Large sentinel used as the "no candidate" weight.
const int inf = 0x3f3f3f3f;

// n: number of tree vertices, m: number of (x, y) pairs (both read in solve()).
int n, m;
// fa: disjoint-set parent of each pair index (see find()/merge()).
int fa[maxn];
// pa: tree parent of each vertex, as read in solve().
int pa[maxn];

// One input pair; both fields are tree vertices.
struct que {
    int x;
    int y;
};

que a[maxn];

struct graph {
    int hd[maxn], to[maxn], nxt[maxn], len;

    inline void add_edge(int u, int v) {
        to[++len] = v;
        nxt[len] = hd[u];
        hd[u] = len;
    }
} G;

// Returns the representative of x's set in the disjoint-set forest fa[] and
// re-points every element on the walked path directly at that representative.
// Iterative on purpose: merge() links roots without balancing, so a chain can
// be as long as the number of elements, which is too deep for recursion.
int find(int x) {
    // Pass 1: locate the root.
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    // Pass 2: compress the path from x up to the root.
    while (fa[x] != root) {
        const int up = fa[x];
        fa[x] = root;
        x = up;
    }
    return root;
}

// Unites the sets containing x and y by hanging x's representative under y's.
// Returns true if the two were in different sets, false if already joined.
inline bool merge(int x, int y) {
    const int rx = find(x);
    const int ry = find(y);
    if (rx == ry) {
        return false;
    }
    fa[rx] = ry;
    return true;
}

int st[maxn], ed[maxn], tim, rnk[maxn], sz[maxn];
int tot, in[maxn], out[maxn], ord[maxn << 1];

void dfs(int u) {
    st[u] = ++tim;
    in[u] = ++tot;
    ord[tot] = u;
    sz[u] = 1;
    rnk[tim] = u;
    for (int i = G.hd[u]; i; i = G.nxt[i]) {
        int v = G.to[i];
        dfs(v);
        sz[u] += sz[v];
    }
    ed[u] = tim;
    out[u] = ++tot;
    ord[tot] = u;
}

// A "best two" record: (x1, f1) is the smallest value seen with its tag and
// (x2, f2) a runner-up; operator+ only admits a runner-up whose tag differs
// from f1. All four fields default to 0.
struct node {
    int x1, f1, x2, f2;

    node(int a = 0, int b = 0, int c = 0, int d = 0) {
        x1 = a;
        f1 = b;
        x2 = c;
        f2 = d;
    }
} c[maxn];

// Per-slot (weight, tag) scratch pairs; solve() resets them every round.
pii b[maxn];

// Combines two "best two" records. The result keeps the record with the
// smaller x1 as its first entry (the left operand wins ties) and picks as
// second entry the smallest value whose tag differs from the winning tag f1.
inline node operator + (node a, node b) {
    const bool right_wins = a.x1 > b.x1;
    const node &lo = right_wins ? b : a;
    const node &hi = right_wins ? a : b;
    node res = lo;
    if (hi.f1 != lo.f1 && hi.x1 < res.x2) {
        // The other record's best belongs to a different tag and beats x2.
        res.x2 = hi.x1;
        res.f2 = hi.f1;
    } else if (hi.f2 != lo.f1 && hi.x2 < res.x2) {
        // Otherwise its runner-up may still improve x2.
        res.x2 = hi.x2;
        res.f2 = hi.f2;
    }
    return res;
}

struct List {
    int hd[maxn], to[maxn], nxt[maxn], len;

    inline void add(int x, int y) {
        to[++len] = y;
        nxt[len] = hd[x];
        hd[x] = len;
    }
} L1, L2;

// rt[u]: root node id of the SGT1 tree owned by vertex u (0 = empty).
int rt[maxn];

// Dynamically allocated, mergeable segment tree. Node 0 is the shared "empty"
// node; ls/rs hold child ids and a[] the combined record of each node's range.
struct SGT1 {
    int nt, ls[maxn * 3], rs[maxn * 3];
    node a[maxn * 3];

    // Wipes every node allocated so far, restores the empty sentinel in slot 0
    // and resets the allocation counter.
    inline void init() {
        for (int i = nt; i >= 0; --i) {
            ls[i] = 0;
            rs[i] = 0;
            a[i] = node();
        }
        a[0] = node(inf, 0, inf, 0);
        nt = 0;
    }

    // Folds record y into every node on the path from rt down to leaf x of
    // the range [l, r], allocating missing nodes (and writing their ids back
    // through rt or the parent's child slot) on the way.
    void update(int &rt, int l, int r, int x, const node &y) {
        int *slot = &rt;
        for (;;) {
            if (*slot == 0) {
                *slot = ++nt;
                a[*slot] = node(inf, 0, inf, 0);
            }
            const int cur = *slot;
            a[cur] = a[cur] + y;
            if (l == r) {
                break;
            }
            const int mid = (l + r) >> 1;
            if (x <= mid) {
                slot = &ls[cur];
                r = mid;
            } else {
                slot = &rs[cur];
                l = mid + 1;
            }
        }
    }

    // Accumulates into res the records of all existing nodes that cover
    // [ql, qr] within the tree rooted at rt (range [l, r]).
    void query(int rt, int l, int r, int ql, int qr, node &res) {
        if (rt == 0) {
            return;
        }
        if (l >= ql && qr >= r) {
            res = res + a[rt];
            return;
        }
        const int mid = (l + r) >> 1;
        if (mid >= ql) {
            query(ls[rt], l, mid, ql, qr, res);
        }
        if (mid < qr) {
            query(rs[rt], mid + 1, r, ql, qr, res);
        }
    }

    // Merges tree v into tree u (both over [l, r]) and returns the id of the
    // merged root; nodes of u are reused and updated in place.
    int merge(int u, int v, int l, int r) {
        if (u == 0) {
            return v;
        }
        if (v == 0) {
            return u;
        }
        if (l != r) {
            const int mid = (l + r) >> 1;
            ls[u] = merge(ls[u], ls[v], l, mid);
            rs[u] = merge(rs[u], rs[v], mid + 1, r);
            a[u] = a[ls[u]] + a[rs[u]];
        } else {
            a[u] = a[u] + a[v];
        }
        return u;
    }
} T1;

pair<node*, node> stk[maxn * 3];
int top, tp[maxn];

struct SGT2 {
    node a[maxn * 3];
    int N;

    inline void init() {
        N = 1;
        while (N < n + 2) {
            N <<= 1;
        }
        for (int i = 1; i <= N + n; ++i) {
            a[i] = node(inf, 0, inf, 0);
        }
    }

    inline void update(int x, node y) {
        x += N;
        while (x) {
            stk[++top] = mkp(a + x, a[x]);
            a[x] = a[x] + y;
            x >>= 1;
        }
    }

    inline node query(int l, int r) {
        node res(inf, 0, inf, 0);
        for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
            if (!(l & 1)) {
                res = res + a[l ^ 1];
            }
            if (r & 1) {
                res = res + a[r ^ 1];
            }
        }
        return res;
    }
} T2;

// Segment tree with range update and point query: an update tags the nodes
// that exactly cover the range, a query combines the tags on the root-to-leaf
// path. Updates are logged in stk so the caller can roll them back.
struct SGT3 {
    node a[maxn << 2];

    // Resets every node of the subtree rt (covering [l, r]) to the empty record.
    void build(int rt, int l, int r) {
        a[rt] = node(inf, 0, inf, 0);
        if (l < r) {
            const int mid = (l + r) >> 1;
            build(rt << 1, l, mid);
            build(rt << 1 | 1, mid + 1, r);
        }
    }

    // Folds x into each node whose range lies fully inside [ql, qr],
    // recording the previous value of every touched node in stk.
    void update(int rt, int l, int r, int ql, int qr, const node &x) {
        const bool covered = (l >= ql) && (qr >= r);
        if (covered) {
            ++top;
            stk[top].first = a + rt;
            stk[top].second = a[rt];
            a[rt] = a[rt] + x;
            return;
        }
        const int mid = (l + r) >> 1;
        if (mid >= ql) {
            update(rt << 1, l, mid, ql, qr, x);
        }
        if (mid < qr) {
            update(rt << 1 | 1, mid + 1, r, ql, qr, x);
        }
    }

    // Accumulates into res the records of all nodes on the path from rt down
    // to the leaf for position x.
    void query(int rt, int l, int r, int x, node &res) {
        for (;;) {
            res = res + a[rt];
            if (l == r) {
                break;
            }
            const int mid = (l + r) >> 1;
            if (x <= mid) {
                rt = rt << 1;
                r = mid;
            } else {
                rt = rt << 1 | 1;
                l = mid + 1;
            }
        }
    }
} T3;

// Reads the tree and the m pairs, then repeatedly (Boruvka-style rounds) finds
// for every current component its cheapest edge to another component and
// merges along it, summing the chosen weights; prints the total.
//
// Notation used in the comments below: pair i has vertices x_i = a[i].x and
// y_i = a[i].y; "anc" means ancestor-or-self; j is the pair being queried and
// i a candidate partner. Every node record is tagged with the component
// (fa[i]) of the pair it came from, so the "best two with different tags"
// records always offer a candidate outside j's own component.
void solve() {
    n = read();
    m = read();
    // Vertices 2..n each come with their parent; vertex 1 is the root.
    for (int i = 2; i <= n; ++i) {
        pa[i] = read();
        G.add_edge(pa[i], i);
    }
    // Each pair starts as its own component and is bucketed by x (L1) and by y (L2).
    for (int i = 1; i <= m; ++i) {
        a[i].x = read();
        a[i].y = read();
        fa[i] = i;
        L1.add(a[i].x, i);
        L2.add(a[i].y, i);
    }
    dfs(1);
    ll ans = 0;
    while (1) {
        // Stop once every pair shares pair 1's component. Calling find(i) for
        // every i also compresses paths, so fa[i] is i's root in the code below.
        bool fl = 1;
        for (int i = 1; i <= m; ++i) {
            fl &= (find(i) == find(1));
            b[i] = mkp(inf, 0);
        }
        if (fl) {
            break;
        }
        // Case 1: no subtree relation used. Candidate weight is
        // (sz[x_i] + sz[y_i]) + (sz[x_j] + sz[y_j]), with i taken from the
        // global best-two record p.
        node p(inf, 0, inf, 0);
        for (int i = 1; i <= m; ++i) {
            p = p + node(sz[a[i].x] + sz[a[i].y], fa[i], inf, 0);
        }
        for (int i = 1; i <= m; ++i) {
            if (p.f1 != fa[i]) {
                b[fa[i]] = min(b[fa[i]], mkp(p.x1 + sz[a[i].x] + sz[a[i].y], p.f1));
            } else {
                b[fa[i]] = min(b[fa[i]], mkp(p.x2 + sz[a[i].x] + sz[a[i].y], p.f2));
            }
        }
        // Case 2: x_i is an anc of x_j. Vertices are visited in preorder, so
        // c[u] = c[parent] combined with the pairs whose x is u, i.e. the best
        // sz[x_i] + sz[y_i] over pairs with x_i on the root..u path.
        // Candidate weight: sz[x_i] + sz[y_i] + sz[y_j] - sz[x_j].
        for (int i = 1; i <= n; ++i) {
            int u = rnk[i];
            c[u] = node(inf, 0, inf, 0);
            if (u > 1) {
                c[u] = c[pa[u]];
            }
            for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                int j = L1.to[_];
                c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
            }
            for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                int j = L1.to[_];
                if (c[u].f1 != fa[j]) {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].y] - sz[u], c[u].f1));
                } else {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].y] - sz[u], c[u].f2));
                }
            }
        }
        // Case 3: same as case 2 with the roles of x and y swapped
        // (y_i is an anc of y_j). Candidate: sz[x_i] + sz[y_i] + sz[x_j] - sz[y_j].
        for (int i = 1; i <= n; ++i) {
            int u = rnk[i];
            c[u] = node(inf, 0, inf, 0);
            if (u > 1) {
                c[u] = c[pa[u]];
            }
            for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                int j = L2.to[_];
                c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
            }
            for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                int j = L2.to[_];
                if (c[u].f1 != fa[j]) {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] - sz[u], c[u].f1));
                } else {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] - sz[u], c[u].f2));
                }
            }
        }
        // Case 4: x_i lies in the subtree of x_j. Vertices are visited in
        // reverse preorder (children before parents), so c[u] aggregates
        // sz[y_i] - sz[x_i] over pairs with x_i in u's subtree.
        // Candidate: sz[y_i] - sz[x_i] + sz[x_j] + sz[y_j].
        for (int i = n; i; --i) {
            int u = rnk[i];
            c[u] = node(inf, 0, inf, 0);
            for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                int j = L1.to[_];
                c[u] = c[u] + node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0);
            }
            for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
                int v = G.to[_];
                c[u] = c[u] + c[v];
            }
            for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                int j = L1.to[_];
                if (c[u].f1 != fa[j]) {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
                } else {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
                }
            }
        }
        // Cases 5 and 6, both driven by "y_i lies in the subtree of y_j":
        //  - c[u] aggregates sz[x_i] - sz[y_i] over pairs with y_i in u's
        //    subtree; candidate: sz[x_i] - sz[y_i] + sz[x_j] + sz[y_j].
        //  - rt[u] is a merged T1 tree keyed by st[x_i] holding
        //    -(sz[x_i] + sz[y_i]) for the same pairs; querying the preorder
        //    interval of x_j's subtree additionally restricts x_i to that
        //    subtree; candidate: sz[x_j] + sz[y_j] - sz[x_i] - sz[y_i].
        T1.init();
        for (int i = n; i; --i) {
            int u = rnk[i];
            c[u] = node(inf, 0, inf, 0);
            rt[u] = 0;
            for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                int j = L2.to[_];
                c[u] = c[u] + node(sz[a[j].x] - sz[u], fa[j], inf, 0);
                T1.update(rt[u], 1, n, st[a[j].x], node(-sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
            }
            // Children were processed earlier (reverse preorder); absorb their results.
            for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
                int v = G.to[_];
                c[u] = c[u] + c[v];
                rt[u] = T1.merge(rt[u], rt[v], 1, n);
            }
            for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                int j = L2.to[_];
                if (c[u].f1 != fa[j]) {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
                } else {
                    b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
                }
                node res(inf, 0, inf, 0);
                T1.query(rt[u], 1, n, st[a[j].x], ed[a[j].x], res);
                if (res.f1 != fa[j]) {
                    b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] + sz[a[j].y], res.f1));
                } else {
                    b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] + sz[a[j].y], res.f2));
                }
            }
        }
        // Case 7: x_i is an anc of x_j and y_i lies in the subtree of y_j.
        // Walk the enter/leave sequence ord[]: on entering u, insert the pairs
        // with x = u into T2 at position st[y_i]; on leaving u, undo
        // everything logged since u was entered. Hence, while inside u, T2
        // holds exactly the pairs whose x is on the root..u path.
        // Candidate: sz[x_i] - sz[y_i] - sz[x_j] + sz[y_j].
        T2.init();
        top = 0;
        for (int i = 1; i <= tot; ++i) {
            int u = ord[i];
            if (in[u] == i) {
                // Remember the undo-log height so the exit event can roll back to it.
                tp[u] = top;
                for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                    int j = L1.to[_];
                    T2.update(st[a[j].y], node(sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
                }
                for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                    int j = L1.to[_];
                    node res = T2.query(st[a[j].y], ed[a[j].y]);
                    if (res.f1 != fa[j]) {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] + sz[a[j].y], res.f1));
                    } else {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] + sz[a[j].y], res.f2));
                    }
                }
            } else {
                // Exit event: restore every tree node changed since u was entered.
                while (top > tp[u]) {
                    *stk[top].fst = stk[top].scd;
                    --top;
                }
            }
        }
        // Case 8: mirror of case 7 (y_i is an anc of y_j, x_i lies in the
        // subtree of x_j). Candidate: sz[y_i] - sz[x_i] + sz[x_j] - sz[y_j].
        T2.init();
        top = 0;
        for (int i = 1; i <= tot; ++i) {
            int u = ord[i];
            if (in[u] == i) {
                tp[u] = top;
                for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                    int j = L2.to[_];
                    T2.update(st[a[j].x], node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0));
                }
                for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
                    int j = L2.to[_];
                    node res = T2.query(st[a[j].x], ed[a[j].x]);
                    if (res.f1 != fa[j]) {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] - sz[a[j].y], res.f1));
                    } else {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] - sz[a[j].y], res.f2));
                    }
                }
            } else {
                while (top > tp[u]) {
                    *stk[top].fst = stk[top].scd;
                    --top;
                }
            }
        }
        // Case 9: x_i is an anc of x_j and y_i is an anc of y_j. Same
        // enter/leave walk, but each pair with x = u stamps the whole preorder
        // interval of y_i's subtree in T3; a point query at st[y_j] then sees
        // exactly the stamped pairs whose y-subtree contains y_j.
        // Candidate: sz[x_i] + sz[y_i] - sz[x_j] - sz[y_j].
        T3.build(1, 1, n);
        top = 0;
        for (int i = 1; i <= tot; ++i) {
            int u = ord[i];
            if (in[u] == i) {
                tp[u] = top;
                for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                    int j = L1.to[_];
                    T3.update(1, 1, n, st[a[j].y], ed[a[j].y], node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0));
                }
                for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
                    int j = L1.to[_];
                    node res(inf, 0, inf, 0);
                    T3.query(1, 1, n, st[a[j].y], res);
                    if (res.f1 != fa[j]) {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] - sz[a[j].y], res.f1));
                    } else {
                        b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] - sz[a[j].y], res.f2));
                    }
                }
            } else {
                while (top > tp[u]) {
                    *stk[top].fst = stk[top].scd;
                    --top;
                }
            }
        }
        // b[r] now holds (weight, other component) of the cheapest candidate
        // found for each component root r. Merge along it; the weight is only
        // counted when the merge actually joined two different components.
        for (int i = 1; i <= m; ++i) {
            if (fa[i] == i && merge(i, b[i].scd)) {
                ans += b[i].fst;
            }
        }
    }
    writeln(ans);
}

// Program entry point: the input holds a single test case, so solve() runs once.
int main() {
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}