P6071 『MdOI R1』Treequery - Solution
strcmp
·
·
题解
典题,感觉数据结构稍微多做点就能秒得很快了。
先钦定 1 为根。
显然 [l,\,r] 要么都在 p 子树内,要么都在 p 子树外。否则答案一定为 0。
怎么算 $[l,\,r]$ 的公共 LCA?一种方法是计算 $\text{mindfn}$ 和 $\text{maxdfn}$ 的 LCA,一种是直接 $a_i \leftarrow \text{LCA}(i,\,i + 1)$ 然后算 $[l,\,r - 1]$ 的深度最值。我选择了后者。至于为什么后者是对的[详见这里](https://www.luogu.com.cn/article/a6f6tjgw)。
接下来考虑 $[l,\,r]$ 在 $p$ 子树外。
先找到 $p$ 向上走到的最深祖先 $x$,满足 $x$ 子树内存在任意一个编号在 $[l,\,r]$ 内的点。
两种情况:
1. $x$ 包含了所有 $[l,\,r]$ 的点,那么一样,求出公共 LCA,然后答案就是 LCA 到 $p$ 的距离。
2. $x$ 没有包含所有 $[l,\,r]$ 的点,那么答案一定是 $x$ 到 $p$ 的距离。
至此答案是什么就讨论完了,每次相当于倍增,判定 $u$ 子树内有多少个 $[l,\,r]$ 内的点。
持久化线段树合并即可。时间复杂度 $\Theta(n \log^2 n)$,空间复杂度 $\Theta(n \log n)$,可以通过。
写完代码直接 1A1C 了,舒服。
这个做法基本没有思维难度,代码难度也很低。
```cpp
#include <bits/stdc++.h>
#define X first
#define Y second
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define per(i, a, b) for (int i = a; i >= b; i--)
#define pb push_back
#define mp make_pair
#define mid (l + r >> 1)
using namespace std;
typedef long long int ll;
using ull = unsigned long long int;
using pii = pair<int, int>;
constexpr int maxn = 2e5 + 10, mx = 2e6 + 5, mod = 998244353;
struct edge { int to, nxt, w; } nd[maxn << 1]; int h[maxn], cnt = 0, n, Q, lg[maxn], ans;
inline void add(int u, int v, int w) { nd[cnt].nxt = h[u], nd[cnt].to = v, nd[cnt].w = w, h[u] = cnt++; }
struct Node { int l, r, v; } t[maxn * 41];
#define ls(x) (t[x].l)
#define rs(x) (t[x].r)
#define val(x) (t[x].v)
#define mid (l + r >> 1)
int fa[maxn][18], a[18][maxn], d[maxn], s[maxn], rt[maxn], tot = 0;
void mg(int& x, int& y) {
if (!x || !y) x |= y;
else {
int u = ++tot; t[u] = t[x]; x = u; val(u) += val(y);
mg(ls(u), ls(y)), mg(rs(u), rs(y));
}
}
void mdf(int l, int r, int k, int p, int& x) {
t[x = ++tot] = t[p]; ++val(x); if (l == r) return void();
return k <= mid ? mdf(l, mid, k, ls(p), ls(x)) : mdf(mid + 1, r, k, rs(p), rs(x));
}
int qry(int l, int r, int ql, int qr, int x) {
if (!x || ql <= l && r <= qr) return val(x); int sum = 0;
(ql <= mid) && (sum += qry(l, mid, ql, qr, ls(x))), (qr > mid) && (sum += qry(mid + 1, r, ql, qr, rs(x))); return sum;
}
void dfs(int u, int f) {
d[u] = d[fa[u][0] = f] + 1;
rep(i, 1, 17) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = h[u]; ~i; i = nd[i].nxt) {
int v = nd[i].to;
if (v != f) {
s[v] = s[u] + nd[i].w, dfs(v, u);
mg(rt[u], rt[v]);
}
}
mdf(1, n, u, rt[u], rt[u]);
}
inline int lca(int u, int v) {
if (d[u] < d[v]) swap(u, v);
rep(i, 0, 17) if (d[u] - d[v] >> i & 1) u = fa[u][i];
if (u == v) return u;
per(i, 17, 0) if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
inline int ck(int u, int v) { return d[u] < d[v] ? u : v; }
inline int LCA(int l, int r) { if (l == r) return l; int g = lg[r-- - l]; return ck(a[g][l], a[g][r - (1 << g) + 1]); }
#define F(l, r, x) qry(1, n, l, r, rt[x])
inline void solve(int p, int l, int r) {
int W = 0;
if ((W = F(l, r, p)) == r - l + 1) return void(ans = s[LCA(l, r)] - s[p]);
else {
if (W) return void(ans = 0); int x = p;
per(i, 17, 0) if (fa[x][i] && !F(l, r, fa[x][i])) x = fa[x][i]; x = fa[x][0];
if (F(l, r, x) == r - l + 1) {
int u = LCA(l, r);
return void(ans = s[p] + s[u] - 2 * s[lca(p, u)]);
}
else return void(ans = s[p] - s[x]);
}
}
int main() {
memset(h, -1, sizeof(h)); scanf("%d%d", &n, &Q);
for (int i = 2, u, v, w; i <= n; i++) scanf("%d%d%d", &u, &v, &w), add(u, v, w), add(v, u, w), lg[i] = lg[i >> 1] + 1;
dfs(1, 0); rep(i, 1, n - 1) a[0][i] = lca(i, i + 1);
rep(i, 1, 17) rep(j, 1, n - (1 << i - 1)) a[i][j] = ck(a[i - 1][j], a[i - 1][j + (1 << i - 1)]);
//cout << F(8, 9, 4) << "\n";
while (Q--) {
int p, l, r;
scanf("%d%d%d", &p, &l, &r); p ^= ans, l ^= ans, r ^= ans;
solve(p, l, r); printf("%d\n", ans);
}
return 0;
}
```