题解 P7735【[NOI2021] 轻重边】

· · 题解

提供一种魔改树链剖分的做法。

分析

把题中的重边看成边权 1,轻边看成边权 0。边权显然不太好处理,所以把每条边的边权变成儿子节点的点权。这时,一次修改相当于把路径上所有的点以及它们的儿子赋值为 0,再把路径上 LCA 外的点赋值为 1。一次询问就是问路径上 LCA 外的点的权值和。

容易想到树链剖分。但传统的树链剖分的编号原则是「重链编号连续、子树编号连续」,我们这里不需要维护子树信息,需要知道儿子的信息,所以比较自然想到,用某种方式使得「重链上所有点的儿子编号连续」。

如何实现呢?我的方法是,先按传统的方法把重链都分出来,然后类似于 BFS,建立一个队列,把以根节点为 top 的重链塞进去。每次取出队头的重链,把这条重链上的点的儿子依次编号,然后把以这些儿子为 top 的重链塞进队列。

此外,对每个点 x,再预处理出其所在的重链中,由 x 向上/向下走能碰到的节点的儿子的编号最大/小值,就可以容易调用线段树的操作来实现儿子的赋值。

上述方案唯一的美中不足是,一条重链的 toptop 的重儿子编号可能不连续,所以对于链的操作,每次只能取出 top 的重儿子往下的一段进行操作,再单独操作 top 自己。这会导致常数比较大,大概是传统树剖的 4-5 倍。

显然,与传统树剖的时间复杂度一样,为 O(n \log^2 n)

考场上跑自己脚造的数据会超 0.2s 左右,感觉常数巨大没希望了,但 CCF 的巨牛逼机子最大的点只跑了 0.8s,超赞。

代码

具体实现的时候可以不把链先赋值为 0,以减小常数。

考场上的核心代码,稍微修得好看了一点:

struct SegT {
    int l, r, sum, tg;
    #define ls (p << 1)
    #define rs (p << 1 | 1)
} t[N<<2];
void build(int p, int l, int r) {
    t[p].l = l, t[p].r = r, t[p].sum = 0, t[p].tg = -1;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(ls, l, mid), build(rs, mid + 1, r);
}
inline void push_up(int p) { t[p].sum = t[ls].sum + t[rs].sum; }
inline void push_down(int p) {
    if (t[p].tg != -1) {
        t[ls].sum = t[p].tg * (t[ls].r - t[ls].l + 1), t[ls].tg = t[p].tg;
        t[rs].sum = t[p].tg * (t[rs].r - t[rs].l + 1), t[rs].tg = t[p].tg;
        t[p].tg = -1;
    }
}
void upd(int p, int l, int r, int v) {
    if (l <= t[p].l && t[p].r <= r) return t[p].sum = v * (t[p].r - t[p].l + 1), t[p].tg = v, void();
    int mid = (t[p].l + t[p].r) >> 1;
    push_down(p);
    if (l <= mid) upd(ls, l, r, v);
    if (r > mid) upd(rs, l, r, v);
    push_up(p);
}
int qry(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r) return t[p].sum;
    int mid = (t[p].l + t[p].r) >> 1, res = 0;
    push_down(p);
    if (l <= mid) res += qry(ls, l, r);
    if (r > mid) res += qry(rs, l, r);
    return res;
}

int fa[N], siz[N], son[N], dep[N], tp[N];
int dfn[N], rk[N], sec;
int upmx[N], upmn[N], dwmx[N], dwmn[N];
vector<int> V[N];
int q[N], hd, tl;

void dfs1(int x, int ff, int dpt) {
    fa[x] = ff, dep[x] = dpt;
    siz[x] = 1, son[x] = 0;
    for (int i = g[x]; i; i = nxt[i]) {
        int y = v[i];
        if (y == fa[x]) continue;
        dfs1(y, x, dep[x] + 1);
        if (!son[x] || siz[y] > siz[son[x]]) son[x] = y;
        siz[x] += siz[y];
    }
}

void dfs2(int x) {
    if (!tp[x]) tp[x] = x;
    V[tp[x]].pb(x);
    for (int i = g[x]; i; i = nxt[i]) {
        int y = v[i];
        if (y == fa[x]) continue;
        if (y == son[x]) tp[y] = tp[x];
        dfs2(y);
    }
}

inline void upd0(int x, int y) {
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
        if (son[x]) upd(1, dfn[son[x]], dfn[son[x]], 0);
        if (dwmn[tp[x]] <= upmx[x]) upd(1, dwmn[tp[x]], upmx[x], 0);
        x = fa[tp[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    if (son[x]) upd(1, dfn[son[x]], dfn[son[x]], 0);
    if (dwmn[y] <= upmx[x]) upd(1, dwmn[y], upmx[x], 0);
    upd(1, dfn[y], dfn[y], 0);
}

inline void upd1(int x, int y) {
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
        if (tp[x] != x) upd(1, dfn[son[tp[x]]], dfn[x], 1);
        upd(1, dfn[tp[x]], dfn[tp[x]], 1);
        x = fa[tp[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    if (x != y) upd(1, dfn[son[y]], dfn[x], 1);
}

inline int qry(int x, int y) {
    int res = 0;
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
        if (tp[x] != x) res += qry(1, dfn[son[tp[x]]], dfn[x]);
        res += qry(1, dfn[tp[x]], dfn[tp[x]]);
        x = fa[tp[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    if (x != y) res += qry(1, dfn[son[y]], dfn[x]);
    return res;
}

inline void init() {
    memset(g, 0, sizeof(g)), tot = 0;
    memset(tp, 0, sizeof(tp)), sec = 0;
    for (int i = 1; i <= n; i++) V[i].clear();
}

inline void solve() {
    n = rd(), m = rd();
    init();
    for (int i = 1; i < n; i++) {
        int x = rd(), y = rd();
        add(x, y), add(y, x);
    }
    for (int i = 1; i <= m; i++) {
        qs[i].op = rd(), qs[i].x = rd(), qs[i].y = rd();
    }
    if (n <= 5000 && m <= 5000) {
        sub1::main();
        return;
    }
    dfs1(1, 0, 0);
    dfs2(1);
    hd = tl = 0;
    q[tl++] = 1, dfn[1] = ++sec, rk[sec] = 1;
    while (hd != tl) {
        int x = q[hd++], len = V[x].size();
        for (int i = 1; i < len; i++) {
            int y = V[x][i];
            dfn[y] = ++sec, rk[sec] = y;
        }
        for (int i = 0; i < len; i++) {
            int y = V[x][i];
            upmx[y] = -inf, upmn[y] = inf;
            for (int j = g[y]; j; j = nxt[j]) {
                int z = v[j];
                if (z == fa[y]) continue;
                if (z == son[y]) continue;
                dfn[z] = ++sec, rk[sec] = z;
                q[tl++] = z;
                chkmax(upmx[y], dfn[z]);
                chkmin(upmn[y], dfn[z]);
            }
            dwmx[y] = upmx[y], dwmn[y] = upmn[y];
        }
        for (int i = 1; i < len; i++) {
            chkmax(upmx[V[x][i]], upmx[V[x][i-1]]);
            chkmin(upmn[V[x][i]], upmn[V[x][i-1]]);
        }
        for (int i = len - 2; i >= 0; i--) {
            chkmax(dwmx[V[x][i]], dwmx[V[x][i+1]]);
            chkmin(dwmn[V[x][i]], dwmn[V[x][i+1]]);
        }
    }
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        int op = qs[i].op, x = qs[i].x, y = qs[i].y;
        if (op == 1) upd0(x, y), upd1(x, y);
        else printf("%d\n", qry(x, y));
    }
}