题解 CF983E NN country

Acfboy

2021-07-11 11:54:00

Solution

这题代码其实用不着百来行,如果用主席树实现的话。 遇到树上问题,首先想一想链上的情况。 如果输入是一条链,我们向某个方向跳的时候一定是要尽可能坐车到远的地方,每次暴力找最远的肯定不行,容易想到可以倍增优化。 现在把链上的情况搬到树上。 树上仍然是链的情况,因为要求最少乘车,而在树上两点之间不重复经过点的路径是唯一的,所以只要考虑询问的两个点之间的那条链就可以了。 这里同样可以倍增,把路分成两段,然后两点往它们的 LCA 跳。可是有个问题,万一这个路跳过 LCA 了怎么办? 只需要判断一下是否有一条路能跨过 LCA 就好了! 也就是说要判断有没有一辆车的线路的俩端点分别在两个查询的点在 LCA 上一次跳到的点的两个子树内,这就变成了二维数点问题,可以用扫描线来实现,但是我推荐主席树,因为似乎短一些。具体地,对每个端点点按 dfs 序在权值线段树插入另一个端点,查询时只需要查询一个端点的子树 dfs 区间内的值有是否大于 $0$ 即可。 其实出题人为了让我们代码少写一点还是做了些什么的,比如题目里保证了一个点的父亲一定比它小,这样从大到小就可以避免父亲的影响。 还有个小 trick, 在一棵树上多次插入的时候可以插多棵树,最后把根指向最后那棵,这样就不用在写个修改函数了。 代码。 ```cpp #include <cstdio> #include <vector> #include <algorithm> const int N = 200005; int n, m, cnt, tim, f[N][20], d[N], g[N][20], rt[N], p[N], q[N]; std::vector<int> G[N], bus[N]; struct twt { int ls, rs, s; } t[N*40]; void dfs(int u) { p[u] = ++tim, d[u] = d[f[u][0]] + 1; for (int i = 0; i < (signed)G[u].size(); i++) dfs(G[u][i]); q[u] = tim; } int lca(int x, int y) { if (d[x] < d[y]) std::swap(x, y); int t = d[x] - d[y]; for (int i = 0; i <= 19; i++) if (t & (1 << i)) x = f[x][i]; if (x == y) return x; for (int i = 19; ~i; i--) if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; return f[x][0]; } void add(int &rt, int l, int r, int x) { t[++cnt] = t[rt], t[cnt].s ++; rt = cnt; if (l == r) return; int mid = l + (r-l) / 2; if (x <= mid) add(t[rt].ls, l, mid, x); else add(t[rt].rs, mid+1, r, x); } int Query(int rt, int l, int r, int x, int y) { if (!rt) return 0; if (l == x && r == y) return t[rt].s; int mid = l + (r-l) / 2; if (y <= mid) return Query(t[rt].ls, l, mid, x, y); else if (x > mid) return Query(t[rt].rs, mid+1, r, x, y); else return Query(t[rt].ls, l, mid, x, mid) + Query(t[rt].rs, mid+1, r, mid+1, y); } int main() { scanf("%d", &n); for (int i = 2; i <= n; i++) scanf("%d", &f[i][0]), G[f[i][0]].push_back(i); for (int j = 1; j <= 19; j++) for (int i = 1; i <= n; i++) f[i][j] = f[f[i][j-1]][j-1]; dfs(1); scanf("%d", &m); for (int i = 1; i <= n; i++) g[i][0] = i; for (int i = 1, x, y; i <= m; i++) { scanf("%d%d", &x, &y); int l = lca(x, y); g[x][0] = std::min(g[x][0], l), g[y][0] = std::min(g[y][0], l); bus[p[x]].push_back(p[y]), bus[p[y]].push_back(p[x]); } for (int i = n; i >= 1; i--) g[f[i][0]][0] = std::min(g[f[i][0]][0], g[i][0]); for (int j = 1; j <= 19; j++) for (int i = 1; i <= n; i++) g[i][j] = g[g[i][j-1]][j-1]; for (int i = 1; i <= n; i++) { rt[i] = rt[i-1]; for (int j = 0; j < (signed)bus[i].size(); j++) add(rt[i], 1, n, bus[i][j]); } scanf("%d", &m); for (int i = 1, u, v; i <= m; i++) { scanf("%d%d", &u, &v); if (u == v) { puts("0"); continue; } int l = lca(u, v); if (g[u][19] > l || g[v][19] > l) { puts("-1"); continue; } int ans = 2; for (int i = 19; i >= 0; i--) if (g[u][i] > l) u = g[u][i], ans += (1 << i); for (int i = 19; i >= 0; i--) if (g[v][i] > l) v = g[v][i], ans += (1 << i); if (l == u || l == v) ans --; else ans -= !!(Query(rt[q[u]], 1, n, p[v], q[v]) - Query(rt[p[u]-1], 1, n, p[v], q[v])); printf("%d\n", ans); } return 0; } ```