P10604 BZOJ4317 Atm 的树 题解

· · 题解

分析

直接求第 k 小实在比较难,考虑二分答案,即转化为:给定 u,v,求 u 出发的路径中,有多少条权值和 \leqslant v

——这不是点分树模板?

同样地,在点分树上,对于点 i,暴力记录下点分树上点 i 子树中各点至 i 的路径长。

对于查询“与 i 距离 \leqslant d 的点数”,先在 i 点分树上子树中统计(可以对子树信息排序后用 upper_bound 求),接下来考虑该子树外的贡献。

跳祖先时,若从 f 跳到 pp 子树内的贡献即为 p 子树内与 p 距离 \leqslant d-\operatorname{dist}(p,i) 的点数,减去 f 子树内与 p 距离 \leqslant d-\operatorname{dist}(p,i) 的点数。(点 f 子树内的点不应在 p 的子树中再算一遍)

关于对每个点 $i$ 直接分别存子树内点至 $i$ 与 $f_i$($i$ 的父亲)的距离,因为每个点只会被其祖先各自存 $2$ 次,所以总共也只有 $O(n\log n)$ 级别的空间,开 `vector` 存即可。 ### 代码 ```cpp //...... int choose(int a, int b) {return dfn[a] < dfn[b]? a: b;} void dfs(int k, int pre) { st[0][dfn[k] = ++cnt] = pre; for(auto i : g[k]) if(i.to != pre) dis[i.to] = dis[k] + i.value, dfs(i.to, k); } void init() { dfs(1, 0); for(int i = 1; i < maxl; i++) for(int j = 1; j + (1 << i) - 1 <= n; j++) st[i][j] = choose(st[i - 1][j], st[i - 1][j + (1 << i - 1)]); } int lca(int a, int b) { if(a == b) return a; if(dfn[a] > dfn[b]) a ^= b ^= a ^= b; int len = log2(dfn[b] - dfn[a]); return choose(st[len][dfn[a] + 1], st[len][dfn[b] - (1 << len) + 1]); }//上为 dfn 序求 LCA int distance(int a, int b) {return dis[a] + dis[b] - 2 * dis[lca(a, b)];} void calcsize(int k, int pre, int mode) { int mx = 0; sz[k] = 1; for(auto i : g[k]) if(i.to != pre && !del[i.to]) calcsize(i.to, k, mode), mx = max(mx, sz[i.to]), sz[k] += sz[i.to]; mx = max(mx, nown - sz[k]); if(mode && rtmx > mx) rtmx = mx, rt = k; }//算重心 void build(int k) { del[k] = true; for(auto i : g[k]) if(!del[i.to]) { rtmx = nown = sz[i.to]; calcsize(i.to, -1, 1), calcsize(rt, -1, 0); f[rt] = k, build(rt); } }//建点分树 int query(int cent, int range) { int ret = std::upper_bound(f1[cent].begin(), f1[cent].end(), range) - f1[cent].begin(); for(int i = f[cent], pre = cent; i; i = f[pre = i]) { ret += std::upper_bound(f1[i].begin(), f1[i].end(), range - distance(i, cent)) - f1[i].begin(); ret -= std::upper_bound(f2[pre].begin(), f2[pre].end(), range - distance(i, cent)) - f2[pre].begin();//算贡献时,上一行本不应该包含 pre 子树中的点,所以以 f2 统计算父亲时要减去的贡献 } return ret - 1;//题中的路径不包含单点 } int main() { rtmx = nown = n = read(), rk = read(); for(int i = 1, u, v, w; i < n; i++) { u = read(), v = read(), w = read(); g[u].push_back({v, w}), g[v].push_back({u, w}); } calcsize(1, -1, 1), calcsize(rt, -1, 0), build(rt); init(); for(int i = 1; i <= n; i++) { for(int pos = i; pos; pos = f[pos]) { f1[pos].push_back(distance(pos, i)); if(f[pos]) f2[pos].push_back(distance(f[pos], i)); } } for(int i = 1; i <= n; i++) std::sort(f1[i].begin(), f1[i].end()), std::sort(f2[i].begin(), f2[i].end()); for(int i = 1; i <= n; i++) { int l = 0, r = n * 10; while(l < r) { int mid = l + r >> 1; if(query(i, mid) >= rk) r = mid; else l = mid + 1; } printf("%d\n", l); } return 0; } ```