P8336 [Ynoi2004] 2stmst
EuphoricStar · · 题解
设
完全图 MST 容易想到 Borůvka：每一轮中，问题转化为对每个连通块，求一端为该块中的某一对 $(x_i, y_i)$、另一端不在该块中的最小边权。
按两对端点之间的祖先关系分情况讨论：
- 没有任何限制：直接取全局 $\min$。
-
-
-
-
-
-
-
-
时间复杂度 $O((n + m\log n)\log m)$：Borůvka 共 $O(\log m)$ 轮，每轮做若干次 $O(n)$ 的扫描以及 $O(m)$ 次 $O(\log n)$ 的线段树操作。
代码看起来很长,但是很多内容都是重复的。
// Problem: P8336 [Ynoi2004] 2stmst
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P8336
// Memory Limit: 512 MB
// Time Limit: 6000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Short-hand macros used throughout the solution.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
// Common type aliases.
using ll = long long;
using db = double;
using ull = unsigned long long;
using ldb = long double;
using pii = pair<int, int>;
namespace IO {
// Size of the input buffer and of the output buffer (1 MiB each).
const int maxn = 1 << 20;
char ibuf[maxn], *iS, *iT, obuf[maxn], *oS = obuf;
// Returns the next byte of stdin, refilling ibuf with fread() once [iS, iT)
// is drained. When the refill reads nothing it yields EOF (converted to char)
// and leaves iS == iT == ibuf; a genuine byte always leaves iS > ibuf.
inline char gc() {
    return (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin), (iS == iT ? EOF : *iS++) : *iS++);
}
// Parses the next integer of type T from stdin. Any '-' seen before the first
// digit makes the result negative. Returns 0 if input ends before a digit.
template<typename T = int>
inline T read() {
    char c = gc();
    T x = 0;
    bool f = 0;
    while (c < '0' || c > '9') {
        // End of input: without this check the loop would spin forever,
        // because gc() keeps returning EOF, which is never a digit.
        if (c == static_cast<char>(EOF) && iS == ibuf) {
            return 0;
        }
        f |= (c == '-');
        c = gc();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = gc();
    }
    return f ? ~(x - 1) : x;
}
// Writes the buffered output to stdout and empties the buffer.
inline void flush() {
    fwrite(obuf, 1, oS - obuf, stdout);
    oS = obuf;
}
// Flushes any pending output when the program exits.
struct Flusher {
    ~Flusher() {
        flush();
    }
} AutoFlush;
// Appends one byte to the output buffer, flushing first if it is full.
inline void pc(char ch) {
    if (oS == obuf + maxn) {
        flush();
    }
    *oS++ = ch;
}
// Writes x in decimal (with a leading '-' when negative) to the output buffer.
template<typename T>
inline void write(T x) {
    static char stk[64], *tp = stk;
    // Negate in the unsigned type: negating the most negative value of a
    // signed T overflows, which is undefined behavior.
    typename std::make_unsigned<T>::type u = x;
    if (x < 0) {
        u = ~u + 1;
        pc('-');
    }
    do {
        *tp++ = u % 10;
        u /= 10;
    } while (u);
    while (tp != stk) {
        pc((*--tp) | 48);
    }
}
// write(x) followed by a space.
template<typename T>
inline void writesp(T x) {
    write(x);
    pc(' ');
}
// write(x) followed by a newline.
template<typename T>
inline void writeln(T x) {
    write(x);
    pc('\n');
}
}
using IO::pc;
using IO::read;
using IO::write;
using IO::writeln;
using IO::writesp;
// Capacity shared by the global per-vertex / per-pair tables.
const int maxn = 1000100;
// Sentinel key meaning "no candidate" (see node(inf, 0, inf, 0) below).
const int inf = 0x3f3f3f3f;
int n, m;      // number of tree vertices, number of pairs
int fa[maxn];  // union-find parent over pair indices 1..m
int pa[maxn];  // tree parent of vertex i (read for i >= 2)
// The i-th input pair: two tree vertices x and y.
struct que {
    int x;
    int y;
} a[maxn];
struct graph {
int hd[maxn], to[maxn], nxt[maxn], len;
inline void add_edge(int u, int v) {
to[++len] = v;
nxt[len] = hd[u];
hd[u] = len;
}
} G;
// Returns the representative of x's set and points every vertex on the
// walked path directly at it (full path compression). It is iterative
// because merge() links roots without ranks, so a parent chain can be up to
// m long; recursion that deep can overflow the call stack.
int find(int x) {
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Unites the sets of x and y by hanging x's root under y's root.
// Returns true iff they were in different sets, i.e. a union happened.
inline bool merge(int x, int y) {
    const int rx = find(x);
    const int ry = find(y);
    if (rx == ry) {
        return false;
    }
    fa[rx] = ry;
    return true;
}
int st[maxn], ed[maxn], tim, rnk[maxn], sz[maxn];
int tot, in[maxn], out[maxn], ord[maxn << 1];
void dfs(int u) {
st[u] = ++tim;
in[u] = ++tot;
ord[tot] = u;
sz[u] = 1;
rnk[tim] = u;
for (int i = G.hd[u]; i; i = G.nxt[i]) {
int v = G.to[i];
dfs(v);
sz[u] += sz[v];
}
ed[u] = tim;
out[u] = ++tot;
ord[tot] = u;
}
// Top-two summary of candidates tagged with a component id: (x1, f1) is the
// smallest key seen; (x2, f2) is the smallest key whose id differs from f1.
// An empty summary is node(inf, 0, inf, 0).
struct node {
    int x1, f1, x2, f2;
    node(int k1 = 0, int g1 = 0, int k2 = 0, int g2 = 0) : x1(k1), f1(g1), x2(k2), f2(g2) {}
} c[maxn];
// b[comp]: best (weight, other component) found for component comp this round.
pii b[maxn];
// Combines two top-two summaries. The result's (x1, f1) is the smaller first
// entry; on a tie in x1 the left operand's entry is kept. Its (x2, f2) is the
// better of its own second entry and the other side's best entry whose id
// differs from the result's f1.
inline node operator + (node a, node b) {
    const bool swapped = a.x1 > b.x1;
    const node &lo = swapped ? b : a;
    const node &hi = swapped ? a : b;
    node res = lo;
    const bool hiFirstUsable = hi.x1 < res.x2 && hi.f1 != lo.f1;
    if (hiFirstUsable) {
        res.x2 = hi.x1;
        res.f2 = hi.f1;
    } else if (hi.x2 < res.x2 && hi.f2 != lo.f1) {
        res.x2 = hi.x2;
        res.f2 = hi.f2;
    }
    return res;
}
// Bucket lists keyed by tree vertex. In solve(), L1 buckets pair indices by
// their x-vertex and L2 by their y-vertex. hd[key] is the newest entry,
// nxt[] chains older ones and 0 ends a chain.
struct List {
    int hd[maxn], to[maxn], nxt[maxn], len;
    inline void add(int x, int y) {
        nxt[++len] = hd[x];
        to[len] = y;
        hd[x] = len;
    }
} L1, L2;
int rt[maxn];  // rt[u]: root of u's tree inside T1 (0 = empty)
// Segment trees over positions [1, n] whose nodes are allocated on demand from
// a shared pool and hold top-two summaries. Supports point insert, range query
// and destructive merge. Node 0 is the shared empty node (inf summary).
// NOTE(review): each insert may allocate ~log2(n)+1 nodes; the pool size
// maxn * 3 assumes m * (log2(n) + 1) fits — confirm against the constraints.
struct SGT1 {
    int nt, ls[maxn * 3], rs[maxn * 3];
    node a[maxn * 3];
    // Clears every node handed out since the last init and empties the pool.
    inline void init() {
        for (int i = 0; i <= nt; ++i) {
            ls[i] = 0;
            rs[i] = 0;
            a[i] = node();
        }
        a[0] = node(inf, 0, inf, 0);
        nt = 0;
    }
    // Folds y into every node on the path to leaf x, creating missing nodes.
    void update(int &rt, int l, int r, int x, const node &y) {
        int *slot = &rt;
        for (;;) {
            if (!*slot) {
                *slot = ++nt;
                a[*slot] = node(inf, 0, inf, 0);
            }
            const int cur = *slot;
            a[cur] = a[cur] + y;
            if (l == r) {
                return;
            }
            const int mid = (l + r) >> 1;
            if (x <= mid) {
                slot = &ls[cur];
                r = mid;
            } else {
                slot = &rs[cur];
                l = mid + 1;
            }
        }
    }
    // Folds the summaries covering [ql, qr] into res.
    void query(int rt, int l, int r, int ql, int qr, node &res) {
        if (rt == 0) {
            return;
        }
        if (ql <= l && r <= qr) {
            res = res + a[rt];
            return;
        }
        const int mid = (l + r) >> 1;
        if (ql <= mid) {
            query(ls[rt], l, mid, ql, qr, res);
        }
        if (mid < qr) {
            query(rs[rt], mid + 1, r, ql, qr, res);
        }
    }
    // Merges tree v into tree u (reusing u's nodes) and returns the new root.
    int merge(int u, int v, int l, int r) {
        if (u == 0) {
            return v;
        }
        if (v == 0) {
            return u;
        }
        if (l != r) {
            const int mid = (l + r) >> 1;
            ls[u] = merge(ls[u], ls[v], l, mid);
            rs[u] = merge(rs[u], rs[v], mid + 1, r);
            a[u] = a[ls[u]] + a[rs[u]];
        } else {
            a[u] = a[u] + a[v];
        }
        return u;
    }
} T1;
// Undo log shared by T2 and T3: each entry is (address of a tree node, value
// it held before a write). top is the log height; solve() stores in tp[u] the
// height at which vertex u was entered so u's writes can be undone on exit.
pair<node*, node> stk[maxn * 3];
int top;
int tp[maxn];
// Iterative bottom-up segment tree. Leaves start at index N, the smallest
// power of two >= n + 2, so the open query bounds l - 1 and r + 1 fit.
// Every node write made by update() is first logged to stk for rollback.
struct SGT2 {
    node a[maxn * 3];
    int N;
    // Chooses N and resets nodes 1..N+n (all nodes read for positions 1..n).
    inline void init() {
        for (N = 1; N < n + 2; N <<= 1) {
        }
        std::fill(a + 1, a + N + n + 1, node(inf, 0, inf, 0));
    }
    // Folds candidate y into leaf x and all of its ancestors.
    inline void update(int x, node y) {
        for (int pos = x + N; pos != 0; pos >>= 1) {
            stk[++top] = mkp(a + pos, a[pos]);
            a[pos] = a[pos] + y;
        }
    }
    // Combines the summaries of positions l..r.
    inline node query(int l, int r) {
        node acc(inf, 0, inf, 0);
        int lo = l + N - 1;
        int hi = r + N + 1;
        while ((lo ^ hi) != 1) {
            if ((lo & 1) == 0) {
                acc = acc + a[lo ^ 1];
            }
            if ((hi & 1) != 0) {
                acc = acc + a[hi ^ 1];
            }
            lo >>= 1;
            hi >>= 1;
        }
        return acc;
    }
} T2;
// Recursive segment tree over [1, n] in heap layout (children 2k, 2k+1).
// Range updates fold a candidate into the O(log n) covering nodes, logging
// each write to stk. A point query combines every node on the root-to-leaf path.
struct SGT3 {
    node a[maxn << 2];
    // Resets all nodes of the subtree rooted at rt covering [l, r].
    void build(int rt, int l, int r) {
        a[rt] = node(inf, 0, inf, 0);
        if (l != r) {
            const int mid = (l + r) >> 1;
            build(rt << 1, l, mid);
            build(rt << 1 | 1, mid + 1, r);
        }
    }
    // Folds x into the nodes that exactly cover [ql, qr].
    void update(int rt, int l, int r, int ql, int qr, const node &x) {
        if (ql > l || r > qr) {
            const int mid = (l + r) >> 1;
            if (ql <= mid) {
                update(rt << 1, l, mid, ql, qr, x);
            }
            if (qr > mid) {
                update(rt << 1 | 1, mid + 1, r, ql, qr, x);
            }
            return;
        }
        stk[++top] = mkp(a + rt, a[rt]);
        a[rt] = a[rt] + x;
    }
    // Folds every summary on the path from rt down to leaf x into res.
    void query(int rt, int l, int r, int x, node &res) {
        for (;;) {
            res = res + a[rt];
            if (l == r) {
                return;
            }
            const int mid = (l + r) >> 1;
            if (x <= mid) {
                rt = rt << 1;
                r = mid;
            } else {
                rt = rt << 1 | 1;
                l = mid + 1;
            }
        }
    }
} T3;
void solve() {
n = read();
m = read();
for (int i = 2; i <= n; ++i) {
pa[i] = read();
G.add_edge(pa[i], i);
}
for (int i = 1; i <= m; ++i) {
a[i].x = read();
a[i].y = read();
fa[i] = i;
L1.add(a[i].x, i);
L2.add(a[i].y, i);
}
dfs(1);
ll ans = 0;
while (1) {
bool fl = 1;
for (int i = 1; i <= m; ++i) {
fl &= (find(i) == find(1));
b[i] = mkp(inf, 0);
}
if (fl) {
break;
}
node p(inf, 0, inf, 0);
for (int i = 1; i <= m; ++i) {
p = p + node(sz[a[i].x] + sz[a[i].y], fa[i], inf, 0);
}
for (int i = 1; i <= m; ++i) {
if (p.f1 != fa[i]) {
b[fa[i]] = min(b[fa[i]], mkp(p.x1 + sz[a[i].x] + sz[a[i].y], p.f1));
} else {
b[fa[i]] = min(b[fa[i]], mkp(p.x2 + sz[a[i].x] + sz[a[i].y], p.f2));
}
}
for (int i = 1; i <= n; ++i) {
int u = rnk[i];
c[u] = node(inf, 0, inf, 0);
if (u > 1) {
c[u] = c[pa[u]];
}
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
}
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
if (c[u].f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].y] - sz[u], c[u].f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].y] - sz[u], c[u].f2));
}
}
}
for (int i = 1; i <= n; ++i) {
int u = rnk[i];
c[u] = node(inf, 0, inf, 0);
if (u > 1) {
c[u] = c[pa[u]];
}
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
}
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
if (c[u].f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] - sz[u], c[u].f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] - sz[u], c[u].f2));
}
}
}
for (int i = n; i; --i) {
int u = rnk[i];
c[u] = node(inf, 0, inf, 0);
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
c[u] = c[u] + node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0);
}
for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
int v = G.to[_];
c[u] = c[u] + c[v];
}
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
if (c[u].f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
}
}
}
T1.init();
for (int i = n; i; --i) {
int u = rnk[i];
c[u] = node(inf, 0, inf, 0);
rt[u] = 0;
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
c[u] = c[u] + node(sz[a[j].x] - sz[u], fa[j], inf, 0);
T1.update(rt[u], 1, n, st[a[j].x], node(-sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
}
for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
int v = G.to[_];
c[u] = c[u] + c[v];
rt[u] = T1.merge(rt[u], rt[v], 1, n);
}
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
if (c[u].f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
}
node res(inf, 0, inf, 0);
T1.query(rt[u], 1, n, st[a[j].x], ed[a[j].x], res);
if (res.f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] + sz[a[j].y], res.f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] + sz[a[j].y], res.f2));
}
}
}
T2.init();
top = 0;
for (int i = 1; i <= tot; ++i) {
int u = ord[i];
if (in[u] == i) {
tp[u] = top;
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
T2.update(st[a[j].y], node(sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
}
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
node res = T2.query(st[a[j].y], ed[a[j].y]);
if (res.f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] + sz[a[j].y], res.f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] + sz[a[j].y], res.f2));
}
}
} else {
while (top > tp[u]) {
*stk[top].fst = stk[top].scd;
--top;
}
}
}
T2.init();
top = 0;
for (int i = 1; i <= tot; ++i) {
int u = ord[i];
if (in[u] == i) {
tp[u] = top;
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
T2.update(st[a[j].x], node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0));
}
for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
int j = L2.to[_];
node res = T2.query(st[a[j].x], ed[a[j].x]);
if (res.f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] - sz[a[j].y], res.f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] - sz[a[j].y], res.f2));
}
}
} else {
while (top > tp[u]) {
*stk[top].fst = stk[top].scd;
--top;
}
}
}
T3.build(1, 1, n);
top = 0;
for (int i = 1; i <= tot; ++i) {
int u = ord[i];
if (in[u] == i) {
tp[u] = top;
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
T3.update(1, 1, n, st[a[j].y], ed[a[j].y], node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0));
}
for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
int j = L1.to[_];
node res(inf, 0, inf, 0);
T3.query(1, 1, n, st[a[j].y], res);
if (res.f1 != fa[j]) {
b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] - sz[a[j].y], res.f1));
} else {
b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] - sz[a[j].y], res.f2));
}
}
} else {
while (top > tp[u]) {
*stk[top].fst = stk[top].scd;
--top;
}
}
}
for (int i = 1; i <= m; ++i) {
if (fa[i] == i && merge(i, b[i].scd)) {
ans += b[i].fst;
}
}
}
writeln(ans);
}
// Entry point: runs solve() once per test case (a single case by default).
int main() {
    int T = 1;
    // scanf("%d", &T);
    for (; T > 0; --T) {
        solve();
    }
    return 0;
}