题解:P1001 A+B Problem

· · 题解

思路:

动态树模板题,这里使用功能较强的 SATT(Self-Adjusting Top Tree,自调整 Top Tree)解决。

建立两个节点,一个点权为 A,另一个点权为 B,并在它们之间连一条边,则这两点之间路径上的点权和即为 A+B。

代码:

#include <iostream>
// Child-slot accessors into T: slot 0 = left, slot 1 = right, slot 2 = middle.
#define ls(x) T[x][0]
#define rs(x) T[x][1]
#define ms(x) T[x][2]
using namespace std;
// Capacity of every per-node array below. Node ids start at 1; id 0 is
// treated as "no node" by the helpers (see the guards in psr and stf).
const int MAXN = 300005;
// T[x][k] : child of node x in slot k (accessed through ls/rs/ms).
// s[x][0] : sum of v over x and its left/right subtrees (computed by psu for ty == 0).
// s[x][1] : sum that additionally includes the middle subtree (computed by psu).
// tot     : highest node id handed out so far by nnd.
// v[x]    : value stored at node x.
// r[x]    : pending-reversal flag, toggled by psr and cleared by psd.
// top, st : stack of freed node ids (pushed by cle, popped by nnd).
// f[x]    : parent of node x (set by stf and rtt).
int T[MAXN][3], s[MAXN][2], tot, v[MAXN], r[MAXN], top, st[MAXN], f[MAXN];
// Hands out a node id. When the free stack `st` is non-empty the most
// recently freed id is popped and reused; otherwise the next never-used id
// (++tot) is returned.
int nnd() {
    if (top == 0) return ++tot;
    const int reused = st[top];
    --top;
    return reused;
}
// Returns true when x is the root of its left/right tree, i.e. its parent
// f[x] holds x in neither the left nor the right slot (for example because
// x hangs in the parent's middle slot).
bool isr(int x) {
    const int parent = f[x];
    const bool is_right = rs(parent) == x;
    const bool is_left = ls(parent) == x;
    return !(is_right || is_left);
}
// Side of x below its parent: true (1) when x sits in the right slot of
// f[x], false (0) otherwise. Callers use the result as a slot index, so it
// is only meaningful when x is a left/right child.
bool dir(int x) {
    const int parent = f[x];
    return x == rs(parent);
}
// Recomputes the aggregates of node x from its children.
//   ty == 0 : refreshes s[x][0] (left + own value + right) and s[x][1]
//             (left + middle + right + own value).
//   ty != 0 : refreshes only s[x][1] as the sum of the three children's
//             s[..][1]; the node's own value is not added.
void psu(int x, int ty) {
    const int left = ls(x);
    const int right = rs(x);
    const int middle = ms(x);
    if (ty == 0) {
        s[x][0] = s[left][0] + v[x] + s[right][0];
        s[x][1] = s[left][1] + s[middle][1] + s[right][1] + v[x];
    } else {
        s[x][1] = s[left][1] + s[right][1] + s[middle][1];
    }
}
// Applies a reversal to node x: exchanges its left and right children
// immediately and toggles the flag r[x] so that psd later forwards the
// reversal to the children. Does nothing for the null id 0.
void psr(int x) {
    if (x != 0) {
        swap(ls(x), rs(x));
        r[x] ^= 1;
    }
}
// Pushes a pending reversal on x down to its left and right children and
// clears the flag. Only nodes handled with ty == 0 carry reversal flags;
// for any other ty the call returns without touching anything.
void psd(int x, int ty) {
    if (ty != 0 || r[x] == 0) return;
    psr(ls(x));
    psr(rs(x));
    r[x] = 0;
}
// Pushes down pending tags along the path from the root of x's left/right
// tree to x, topmost node first: the recursion climbs through f[] until it
// reaches a node for which isr holds, then applies psd on the way back.
void upd(int x, int ty) {
    const bool has_tree_parent = !isr(x);
    if (has_tree_parent) {
        upd(f[x], ty);
    }
    psd(x, ty);
}
// Makes x the child of fa in slot ty (0 = left, 1 = right, 2 = middle).
// The slot is always overwritten, so passing x == 0 clears it; the parent
// link f[x] is only written for a real (non-zero) x.
void stf(int x, int fa, int ty) {
    T[fa][ty] = x;
    if (x != 0) {
        f[x] = fa;
    }
}
// Rotates x one level up, above its parent y = f[x], within the left/right
// tree. `ty` is forwarded to psu to select which aggregates are recomputed.
// No check is made here that x really is a left/right child of y; spy only
// calls rtt on nodes for which isr is false.
void rtt(int x, int ty) {
    // d: side of y that holds x (0 = left, 1 = right).
    // w: x's child on the opposite side, which will be handed over to y.
    int y = f[x], z = f[y], d = dir(x), w = T[x][d ^ 1];
    // Re-point the grandparent slot that referenced y: the middle slot if y
    // hung there, otherwise y's left/right side. Must run before f[y] changes,
    // because dir(y) reads f[y].
    if (z) T[z][ms(z) == y ? 2 : dir(y)] = x;
    T[x][d ^ 1] = y;  // y becomes x's child on the side opposite to d
    T[y][d] = w;      // w fills the slot x vacated under y
    if (w) f[w] = y;
    f[y] = x;
    f[x] = z;
    // y now sits below x, so its aggregate has to be refreshed first.
    psu(y, ty);
    psu(x, ty);
}
// Splays x upward inside its left/right tree. Pending tags on the path to
// x are pushed down first (upd); x is then rotated until it is either the
// root of that tree (isr) or a direct child of `gl`.
// Each step is a double rotation when the parent can also move up: x is
// rotated first if x and its parent lie on different sides, otherwise the
// parent is rotated first. A final single rotation of x completes the step.
void spy(int x, int ty, int gl = 0) {
    upd(x, ty);
    while (!isr(x) && f[x] != gl) {
        const int parent = f[x];
        if (f[parent] != gl && !isr(parent)) {
            const bool opposite_sides = dir(x) != dir(parent);
            rtt(opposite_sides ? x : parent, ty);
        }
        rtt(x, ty);
    }
}
// Wipes node x and recycles its id: the three child slots, both
// aggregates, the reversal flag and the stored value are zeroed, then x is
// pushed onto the free stack `st` for nnd to reuse. The parent link f[x] is
// left untouched.
void cle(int x) {
    ls(x) = 0;
    ms(x) = 0;
    rs(x) = 0;
    s[x][0] = 0;
    s[x][1] = 0;
    r[x] = 0;
    v[x] = 0;
    top += 1;
    st[top] = x;
}
// Removes node x from underneath its parent f[x] and recycles it.
// x's middle child is moved into the right slot of f[x]; whatever remains
// of x's left/right tree is re-hung in the middle slot of f[x].
// As called from spl, x has just been splayed with ty == 1 and the right
// slot of f[x] is empty.
void del(int x) {
    // x's middle child becomes the right child of x's parent.
    stf(ms(x), f[x], 1);
    if (ls(x)) {
        // Find the rightmost node p of x's left subtree.
        // (psd returns immediately for ty == 1, so these calls change nothing.)
        int p = ls(x);
        psd(p, 1);
        while (rs(p)) p = rs(p), psd(p, 1);
        // Bring p up until it is a direct child of x; being rightmost, it
        // then has no right child of its own.
        spy(p, 1, x);
        // p adopts x's right subtree and replaces x in the parent's middle slot.
        stf(rs(x), p, 1);
        stf(p, f[x], 2);
        psu(p, 1);
        psu(f[x], 0);
    } else
        // No left subtree: x's right child (possibly 0) takes over the middle slot.
        stf(rs(x), f[x], 2);
    cle(x);
}
// Splice step used by acs. x is handled as a ty == 1 node whose middle
// child is to become the right child of x's parent y.
void spl(int x) {
    // Bring x to the root of its ty == 1 tree so that f[x] is the node it hangs under.
    spy(x, 1);
    int y = f[x];
    // Bring y to the root of its ty == 0 tree.
    spy(y, 0);
    psd(x, 1);  // returns immediately for ty == 1
    if (rs(y)) {
        // y already has a right child: exchange it with x's middle child,
        // swapping both the parent links and the slots.
        swap(f[ms(x)], f[rs(y)]);
        swap(ms(x), rs(y));
        psu(x, 1);
    } else
        // y has no right child: del moves x's middle child into rs(y) and frees x.
        del(x);
    // Refresh the new right child first, then y, which depends on it.
    psu(rs(y), 0);
    psu(y, 0);
}
// Access operation: restructures the trees so that x ends up, after the
// final splay, as the root of the outermost ty == 0 tree.
void acs(int x) {
    spy(x, 0);
    int ys = x;  // remember the accessed node; x is reused as a cursor below
    if (rs(x)) {
        // Detach x's right subtree: allocate a fresh ty == 1 node y, give it
        // x's old middle child as left child and x's right subtree as middle
        // child, then hang y in x's middle slot.
        int y = nnd();
        stf(ms(x), y, 0);
        stf(rs(x), y, 2);
        rs(x) = 0;
        stf(y, x, 2);
        psu(y, 1);
        psu(x, 0);
    }
    // Climb towards the top: splice at the node above x, then continue from
    // x's new parent, until x has no parent left.
    while (f[x]) {
        spl(f[x]);
        x = f[x];
    }
    // Finally splay the originally accessed node to the root.
    spy(ys, 0);
}
// Returns the representative of x's tree: accesses x, then follows left
// children as far as possible (pushing reversal flags down on the way) and
// splays the node reached to the root before returning it. lnk compares the
// results for two nodes to decide whether they are already connected.
int fdr(int x) {
    acs(x);
    int cur = x;
    psd(cur, 0);
    while (ls(cur) != 0) {
        cur = ls(cur);
        psd(cur, 0);
    }
    spy(cur, 0);
    return cur;
}
// Accesses x and then applies a reversal to it (psr swaps x's left/right
// children and toggles its reversal flag). Presumably this re-roots the
// represented tree at x, as in link-cut trees.
void mkr(int x) {
    acs(x);
    psr(x);
}
// Exposes the pair (x, y): calls mkr(x) followed by acs(y). main reads the
// aggregate s[y][0] afterwards as the sum of values on the x-y path.
void epo(int x, int y) {
    mkr(x);
    acs(y);
}
// Connects x and y unless fdr reports the same representative for both, in
// which case nothing is changed. Otherwise x is accessed, mkr is applied to
// y, y is attached as x's right child, and both aggregates are recomputed.
void lnk(int x, int y) {
    if (fdr(x) != fdr(y)) {
        acs(x);
        mkr(y);
        stf(y, x, 1);
        psu(x, 0);
        psu(y, 0);
    }
}
// Disconnects x from y. After exposing the pair, the link is removed only
// if x is y's left child and x has no right child; in any other case the
// structure is left as epo produced it.
void cu(int x, int y) {
    epo(x, y);
    const bool directly_linked = (ls(y) == x) && (rs(x) == 0);
    if (!directly_linked) return;
    ls(y) = 0;
    f[x] = 0;
    psu(y, 0);
}
// Reads A and B as the values of nodes 1 and 2, links the two nodes,
// exposes the pair and prints the aggregate s[2][0], which is A + B.
int main() {
    cin >> v[1] >> v[2];
    tot = 2;  // ids 1 and 2 are in use; nnd continues from 3
    for (int node = 1; node <= 2; ++node) {
        psu(node, 0);
    }
    lnk(1, 2);
    epo(1, 2);
    cout << s[2][0];
    return 0;
}