题解:P1001 A+B Problem
liuzhongrui · · 题解
思路:
动态树板题,这里使用比较强的 SATT 解决。
设两个节点,一个点权为
代码:
#include <iostream>
#define ls(x) T[x][0]
#define rs(x) T[x][1]
#define ms(x) T[x][2]
using namespace std;
const int MAXN = 300005;
int T[MAXN][3], s[MAXN][2], tot, v[MAXN], r[MAXN], top, st[MAXN], f[MAXN];
int nnd() {
if (top) {
top--;
return st[top + 1];
}
return ++tot;
}
bool isr(int x) {
return rs(f[x]) != x && ls(f[x]) != x;
}
bool dir(int x) {
return rs(f[x]) == x;
}
void psu(int x, int ty) {
if (ty) {
s[x][1] = s[ls(x)][1] + s[rs(x)][1] + s[ms(x)][1];
return;
}
s[x][0] = s[ls(x)][0] + v[x] + s[rs(x)][0];
s[x][1] = s[ls(x)][1] + s[ms(x)][1] + s[rs(x)][1] + v[x];
}
void psr(int x) {
if (!x) return;
r[x] ^= 1;
swap(ls(x), rs(x));
}
void psd(int x, int ty) {
if (ty) return;
if (r[x]) {
psr(ls(x));
psr(rs(x));
r[x] = 0;
}
}
void upd(int x, int ty) {
if (!isr(x)) upd(f[x], ty);
psd(x, ty);
}
void stf(int x, int fa, int ty) {
if (x) f[x] = fa;
T[fa][ty] = x;
}
void rtt(int x, int ty) {
int y = f[x], z = f[y], d = dir(x), w = T[x][d ^ 1];
if (z) T[z][ms(z) == y ? 2 : dir(y)] = x;
T[x][d ^ 1] = y;
T[y][d] = w;
if (w) f[w] = y;
f[y] = x;
f[x] = z;
psu(y, ty);
psu(x, ty);
}
void spy(int x, int ty, int gl = 0) {
upd(x, ty);
for (int y; y = f[x], (!isr(x)) && y != gl; rtt(x, ty)) {
if (f[y] != gl && (!isr(y))) rtt(dir(x) ^ dir(y) ? x : y, ty);
}
}
void cle(int x) {
ls(x) = ms(x) = rs(x) = s[x][0] = s[x][1] = r[x] = v[x] = 0;
st[++top] = x;
}
void del(int x) {
stf(ms(x), f[x], 1);
if (ls(x)) {
int p = ls(x);
psd(p, 1);
while (rs(p)) p = rs(p), psd(p, 1);
spy(p, 1, x);
stf(rs(x), p, 1);
stf(p, f[x], 2);
psu(p, 1);
psu(f[x], 0);
} else
stf(rs(x), f[x], 2);
cle(x);
}
void spl(int x) {
spy(x, 1);
int y = f[x];
spy(y, 0);
psd(x, 1);
if (rs(y)) {
swap(f[ms(x)], f[rs(y)]);
swap(ms(x), rs(y));
psu(x, 1);
} else
del(x);
psu(rs(y), 0);
psu(y, 0);
}
void acs(int x) {
spy(x, 0);
int ys = x;
if (rs(x)) {
int y = nnd();
stf(ms(x), y, 0);
stf(rs(x), y, 2);
rs(x) = 0;
stf(y, x, 2);
psu(y, 1);
psu(x, 0);
}
while (f[x]) {
spl(f[x]);
x = f[x];
}
spy(ys, 0);
}
int fdr(int x) {
acs(x);
psd(x, 0);
while (ls(x)) x = ls(x), psd(x, 0);
spy(x, 0);
return x;
}
void mkr(int x) {
acs(x);
psr(x);
}
void epo(int x, int y) {
mkr(x);
acs(y);
}
void lnk(int x, int y) {
if (fdr(x) == fdr(y)) return;
acs(x);
mkr(y);
stf(y, x, 1);
psu(x, 0);
psu(y, 0);
}
void cu(int x, int y) {
epo(x, y);
if (ls(y) != x || rs(x)) return;
f[x] = ls(y) = 0;
psu(y, 0);
}
int main() {
tot = 2;
cin >> v[1] >> v[2];
psu(1, 0);
psu(2, 0);
lnk(1, 2);
epo(1, 2);
cout << s[2][0];
return 0;
}