P6071 『MdOI R1』Treequery - Solution

· · 题解

典题,感觉数据结构稍微多做点就能秒得很快了。

先钦定 1 为根。

显然 [l,\,r] 要么都在 p 子树内,要么都在 p 子树外。否则答案一定为 0

怎么算 $[l,\,r]$ 的公共 LCA?一种方法是计算 $\text{mindfn}$ 和 $\text{maxdfn}$ 的 LCA,一种是直接 $a_i \leftarrow \text{LCA}(i,\,i + 1)$ 然后算 $[l,\,r - 1]$ 的深度最值。我选择了后者。至于为什么后者是对的[详见这里](https://www.luogu.com.cn/article/a6f6tjgw)。 接下来考虑 $[l,\,r]$ 在 $p$ 子树外。 先找到 $p$ 向上走到的最深祖先 $x$,满足 $x$ 子树内存在任意一个编号在 $[l,\,r]$ 内的点。 两种情况: 1. $x$ 包含了所有 $[l,\,r]$ 的点,那么一样,求出公共 LCA,然后答案就是 LCA 到 $p$ 的距离。 2. $x$ 没有包含所有 $[l,\,r]$ 的点,那么答案一定是 $x$ 到 $p$ 的距离。 至此答案是什么就讨论完了,每次相当于倍增,判定 $u$ 子树内有多少个 $[l,\,r]$ 内的点。 持久化线段树合并即可。时间复杂度 $\Theta(n \log^2 n)$,空间复杂度 $\Theta(n \log n)$,可以通过。 写完代码直接 1A1C 了,舒服。 这个做法基本没有思维难度,代码难度也很低。 ```cpp #include <bits/stdc++.h> #define X first #define Y second #define rep(i, a, b) for (int i = a; i <= b; i++) #define per(i, a, b) for (int i = a; i >= b; i--) #define pb push_back #define mp make_pair #define mid (l + r >> 1) using namespace std; typedef long long int ll; using ull = unsigned long long int; using pii = pair<int, int>; constexpr int maxn = 2e5 + 10, mx = 2e6 + 5, mod = 998244353; struct edge { int to, nxt, w; } nd[maxn << 1]; int h[maxn], cnt = 0, n, Q, lg[maxn], ans; inline void add(int u, int v, int w) { nd[cnt].nxt = h[u], nd[cnt].to = v, nd[cnt].w = w, h[u] = cnt++; } struct Node { int l, r, v; } t[maxn * 41]; #define ls(x) (t[x].l) #define rs(x) (t[x].r) #define val(x) (t[x].v) #define mid (l + r >> 1) int fa[maxn][18], a[18][maxn], d[maxn], s[maxn], rt[maxn], tot = 0; void mg(int& x, int& y) { if (!x || !y) x |= y; else { int u = ++tot; t[u] = t[x]; x = u; val(u) += val(y); mg(ls(u), ls(y)), mg(rs(u), rs(y)); } } void mdf(int l, int r, int k, int p, int& x) { t[x = ++tot] = t[p]; ++val(x); if (l == r) return void(); return k <= mid ? mdf(l, mid, k, ls(p), ls(x)) : mdf(mid + 1, r, k, rs(p), rs(x)); } int qry(int l, int r, int ql, int qr, int x) { if (!x || ql <= l && r <= qr) return val(x); int sum = 0; (ql <= mid) && (sum += qry(l, mid, ql, qr, ls(x))), (qr > mid) && (sum += qry(mid + 1, r, ql, qr, rs(x))); return sum; } void dfs(int u, int f) { d[u] = d[fa[u][0] = f] + 1; rep(i, 1, 17) fa[u][i] = fa[fa[u][i - 1]][i - 1]; for (int i = h[u]; ~i; i = nd[i].nxt) { int v = nd[i].to; if (v != f) { s[v] = s[u] + nd[i].w, dfs(v, u); mg(rt[u], rt[v]); } } mdf(1, n, u, rt[u], rt[u]); } inline int lca(int u, int v) { if (d[u] < d[v]) swap(u, v); rep(i, 0, 17) if (d[u] - d[v] >> i & 1) u = fa[u][i]; if (u == v) return u; per(i, 17, 0) if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i]; return fa[u][0]; } inline int ck(int u, int v) { return d[u] < d[v] ? u : v; } inline int LCA(int l, int r) { if (l == r) return l; int g = lg[r-- - l]; return ck(a[g][l], a[g][r - (1 << g) + 1]); } #define F(l, r, x) qry(1, n, l, r, rt[x]) inline void solve(int p, int l, int r) { int W = 0; if ((W = F(l, r, p)) == r - l + 1) return void(ans = s[LCA(l, r)] - s[p]); else { if (W) return void(ans = 0); int x = p; per(i, 17, 0) if (fa[x][i] && !F(l, r, fa[x][i])) x = fa[x][i]; x = fa[x][0]; if (F(l, r, x) == r - l + 1) { int u = LCA(l, r); return void(ans = s[p] + s[u] - 2 * s[lca(p, u)]); } else return void(ans = s[p] - s[x]); } } int main() { memset(h, -1, sizeof(h)); scanf("%d%d", &n, &Q); for (int i = 2, u, v, w; i <= n; i++) scanf("%d%d%d", &u, &v, &w), add(u, v, w), add(v, u, w), lg[i] = lg[i >> 1] + 1; dfs(1, 0); rep(i, 1, n - 1) a[0][i] = lca(i, i + 1); rep(i, 1, 17) rep(j, 1, n - (1 << i - 1)) a[i][j] = ck(a[i - 1][j], a[i - 1][j + (1 << i - 1)]); //cout << F(8, 9, 4) << "\n"; while (Q--) { int p, l, r; scanf("%d%d%d", &p, &l, &r); p ^= ans, l ^= ans, r ^= ans; solve(p, l, r); printf("%d\n", ans); } return 0; } ```