题解 SP3978 【DISQUERY - Distance Query】

· · 题解

解题思路

这道题有比较多的做法,树链剖分 + 线段树,LCT 等都能做,但前者有两个 \log,后者常数较大,因此这里给出一个 \log 跑得飞快的树链剖分 + ST 表的做法。

首先,化边权为点权,这个只需要从任意结点开始 dfs 一遍,把边权转到所连的子结点就行了。

之后,维护固定树上的路径,树链剖分当然绰绰有余,现在问题是如何求一条重链上权值最小值和最大值。因为没有修改,显然这是 ST 表的特长,用类似线段树的方法,在以时间戳为下标的 ST 表中搞。

接下来就是两个模板结合的事儿了,注意细节。

代码实现

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <iostream>

inline int read() {
    char c = getchar(); int n = 0; bool f = false;
    while (c < '0' || c > '9') { if (c == '-') { f = true; } c = getchar(); }
    while (c >= '0' && c <= '9') { n = (n << 1) + (n << 3) + (c & 15); c = getchar(); }
    return f ? -n : n;
}

const int maxN = 200005, maxE = maxN << 1, maxLog = 23;

int n, m, t, logN, fth[maxN], dep[maxN], siz[maxN], son[maxN], val[maxN], dfn[maxN], top[maxN], min[maxLog][maxN], max[maxLog][maxN], g[maxN];

struct List {
    int len, fst[maxN], nxt[maxE], v[maxE], w[maxE];

    List() { memset(fst, -1, sizeof(fst)); }
    inline void insert(int u, int vv, int ww) { v[len] = vv, w[len] = ww, nxt[len] = fst[u], fst[u] = len++; }
    inline void connect(int u, int vv, int ww) { insert(u, vv, ww), insert(vv, u, ww); }
} ls;

void dfs1(int u, int fa) {
    fth[u] = fa; dep[u] = dep[fa] + 1; siz[u] = 1;
    for (int i = ls.fst[u]; ~i; i = ls.nxt[i]) {
        int v = ls.v[i], w = ls.w[i];
        if (v != fa) {
            val[v] = w; dfs1(v, u); siz[u] += siz[v];
            if (siz[son[u]] < siz[v]) { son[u] = v; }
        }
    }
}
void dfs2(int u, int p) {
    top[u] = p; dfn[u] = ++t; min[0][t] = max[0][t] = val[u];
    if (son[u]) { dfs2(son[u], p); }
    for (int i = ls.fst[u]; ~i; i = ls.nxt[i]) {
        int v = ls.v[i];
        if (v != fth[u] && v != son[u]) { dfs2(v, v); }   
    }
}

void buildST() {
    for (int i = 1; i <= logN; i++) {
        for (int j = 1; j + (1 << i) - 1 <= n; j++) { min[i][j] = std::min(min[i - 1][j], min[i - 1][j + (1 << i - 1)]); max[i][j] = std::max(max[i - 1][j], max[i - 1][j + (1 << i - 1)]); }
    }
}
inline void queryST(int l, int r, int &resMin, int &resMax) {
    if (l > r) { return; } int logLen = g[r - l + 1];
    resMin = std::min(resMin, std::min(min[logLen][l], min[logLen][r - (1 << logLen) + 1]));
    resMax = std::max(resMax, std::max(max[logLen][l], max[logLen][r - (1 << logLen) + 1]));
}

inline void query(int u, int v) {
    int ansMin = 1e9, ansMax = -1e9;
    while (top[u] != top[v]) {
        if (dep[top[u]] >= dep[top[v]]) { queryST(dfn[top[u]], dfn[u], ansMin, ansMax); u = fth[top[u]]; }
        else { queryST(dfn[top[v]], dfn[v], ansMin, ansMax); v = fth[top[v]]; }
    }
    if (dep[u] < dep[v]) { queryST(dfn[u] + 1, dfn[v], ansMin, ansMax); } else { queryST(dfn[v] + 1, dfn[u], ansMin, ansMax); }
    printf("%d %d\n", ansMin, ansMax);
}

int main() {
    n = read(); while (1 << logN + 1 <= n) { logN++; }
    for (int i = 1; i <= n; i++) { g[i] = log(i) / log(2.0); }
    for (int i = 2; i <= n; i++) { int u = read(), v = read(); ls.connect(u, v, read()); }
    dfs1(1, 0); dfs2(1, 1); buildST();
    for (m = read(); m; m--) { query(read(), read()); }
    return 0;
}