题解:P10658 BZOJ2959 长跑

· · 题解

预备知识

思路

由于一个点可以多次经过,但是每一条边方向必须是固定的,所以可以自然的想到边双连通分量

每到达一个边双连通分量,可以在这个边双连通分量中转圈。这样同一个边双连通分量之中的所有点一定都可以到达。

因此我们可以直接缩点,缩点后的权值为所有点本来值的总和。缩完点的无向图是森林。

然后这就变成了一道用动态树维护双连通分量的题目。缩点我们可以使用并查集。

具体实现

由于我们使用了并查集,所以在标准动态树的基础上,还需要进行略微改动。

具体来说:

  1. access 函数中改动一下。每次向上移动前,将父亲节点设为父亲节点所在并查集的根。

  2. 新增函数 dfs 来实现对于整条实链缩点这一功能。

然后再来说题目的三种询问操作:

  1. 对于操作 1:先用 find_root 函数判断 A 和 B 是否连通。如果不连通直接将它们连起来;否则将这两个点之间的实链分离出来,整条实链缩点。

  2. 对于操作 2:先令 C 为 A 节点所在的并查集的根,把这个 C make_root(C) 成整棵树的根,最后更新 C 的权值。为了更新 C 的权值,需要记录每个节点目前的权值,通过变化的差值进行更新。

  3. 对于操作 3:先判断是否连通,如果连通就分离 A 和 B 之间的实链,输出实链权值和。否则输出 -1。

如果将并查集复杂度当作常数,那么总复杂度为 O(q \log_2 n)

代码


// Author: chenly8128
// Created: 2025-03-21 20:57:47

#include <bits/stdc++.h>
using namespace std;
inline int read (void) {
    int res = 0;bool flag = true;char c = getchar();
    while (c < '0' || c > '9') {flag ^= (c == '-');c = getchar();}
    while (c >= '0' && c <= '9') {res = (res << 3) + (res << 1) + (c ^ 48);c = getchar();}
    return flag ? res : -res;
}
const int MAXN = 2e5+10;
int fa[MAXN],cnt[MAXN];
int find (int x) {
    return fa[x] = fa[x] == x ? x : find(fa[x]);
}
void merge (int a,int b) {
    a = find(a);b = find(b);
    if (a == b) return;
    fa[a] = b;
    cnt[b] += cnt[a];
}
struct LCT {
    int sum[MAXN],ch[MAXN][2],fa[MAXN],tot;
    bool lazy[MAXN];
#define ls(x) ch[x][0]
#define rs(x) ch[x][1]
    inline int dir (int fa,int x) {return ls(fa) == x ? 0 : 1;}
    inline bool isroot (int x) {return ls(fa[x]) != x && rs(fa[x]) != x;}
    inline void push_up (int x) {sum[x] = sum[ls(x)] + sum[rs(x)] + cnt[x];}
    inline void reverse (int x) {if (x) {swap(ls(x),rs(x));lazy[x] ^= true;}}
    inline void push_down (int x) {
        if (lazy[x]) {
            reverse(ls(x));
            reverse(rs(x));
            lazy[x] = false;
        }
    }
    void push (int x) {
        if (!isroot(x)) push(find(fa[x]));
        push_down(x);
    }
    void rotate (int x) {
        if (isroot(x)) return;
        int y = fa[x]; int z = fa[y],r = dir(y,x);
        ch[y][r] = ch[x][r^1];
        ch[x][r^1] = y;
        if (!isroot(y)) ch[z][dir(z,y)] = x;
        if (ch[y][r]) fa[ch[y][r]] = y;
        fa[x] = z;
        fa[y] = x;
        push_up(y);
        push_up(x);
    }
    void splay (int x) {
        push(x);
        int y,z;
        while (!isroot(x)) {
            y = fa[x];z = fa[y];
            if (!isroot(y)) rotate(dir(z,y) == dir(y,x) ? y : x);
            rotate(x);
        }
        push_up(x);
    }
    void access (int y) {
        for (int x = 0;y;x = y,y = fa[x] = find(fa[x])) {
            splay(y);
            rs(y) = x;
            push_up(y);
        }
    }
    void make_root (int x) {
        access(x);splay(x);reverse(x);
    }
    int find_root (int x) {
        access(x);splay(x);
        while (ls(x)) {
            push_down(x);
            x = ls(x);
        }
        splay(x);
        return x;
    }
    void split(int x,int y) {
        make_root(x);
        access(y);
        splay(y);
    }
    void link (int x,int y) {
        make_root(x);make_root(y);
        fa[x] = y;
    }
    void dfs (int x,int rt) {
        merge(x,rt);
        if (ls(x)) dfs(ls(x),rt);
        if (rs(x)) dfs(rs(x),rt);
        ls(x) = rs(x) = 0;
    }
} a;

int n,m,tmp[MAXN];
int op,x,y;
int main (void) {
    n = read(); m = read();
    for (int i = 1;i <= n;i++) {
        tmp[i] = cnt[i] = read();
        fa[i] = i;
    }
    while (m--) {
        op = read(); x = read(); y = read();
        if (op == 2) {
            int p = find(x);
            a.make_root(p);
            cnt[p] += y-tmp[x];
            tmp[x] = y;
            a.push_up(p);
        }
        else {
            x = find(x); y = find(y);
            if (op == 1) {
                if (a.find_root(x) != a.find_root(y)) a.link(x,y);
                else {
                    a.split(x,y);
                    a.dfs(y,y);
                    a.push_up(y);
                }
            }
            else {
                if (a.find_root(x) != a.find_root(y)) printf ("-1\n");
                else {
                    a.split(x,y);
                    printf ("%d\n",a.sum[y]);
                }
            }
        }
    }
    return 0;
}