P6329 点分树 | 震波

皎月半洒花

2020-04-09 21:53:28

Solution

拋磚引玉.jpeg _______ 大概就是如果没有修改操作的话,就是比较裸的点分树。于是先考虑没有修改操作的情况。 考虑怎么维护这个东西,自然是希望对每个点都记录一个桶,但这样显然由于每个点的深度不可控,最终需要的空间代价是 $O(n^2)$ 的。于是考虑怎么调整树的高度使得最终总的空间複杂度可以接受,那自然就会想到点分治。注意到点分治时,每个点在分治过程中,『逻辑树高』都只有 $\log n$ 。这大概就是为什么用点分树的原因。 所以就是建出点分树来,每个点维护一个 `vector` 作为桶,维护点分树上子树内到当前点距离为 $k$ 的点权和。这样对于询问,每次只需要跳点分树,然后对于每个 $fa$ 统计 $k-dis(fa,x)$ 的点对的数量就好了。但是还有一个问题,就是对于以当前 $fa$ 为根的那些子树,在算下一个 $fa$ 的时候会被算重。于是就要再维护一个桶,表示 $x$ 子树内的点,到点分树上 $x$ 的父亲的距离为 $k$ 的点权和。由于边权都为 $1$ ,这个操作就会很方便。 考虑如果带修改,那无非就是把桶换成树状数组即可。这样複杂度就会是 $O(m\log ^2 n)$ 的了。可能我写的比较丑?预处理是常数不小的 $O(n\log ^2 n)$ ,似乎比其他人都慢诶… 以下是第一次写点分树相关题的人可能会遇到的 bug: 1、最开始的时候维护的是 **点分树** 上距离为 $k$ 的点的点权和。 2、然后改了改,但是查询的时候没有维护两个 BIT,只维护了一个,然后减去的是查询 $x$ 的点分树子树内到 $x$ 距离 $\leq k-2\times dis(fa_x,x)$ 的点权和。看上去有点东西,但问题在于到 $x$ 距离和到 $fa_x$ 距离没有本质上的关係…比如可以在树的对侧。 3、最后还是写了两个 BIT,但是调了很久,原因是向上跳遇到 $dis(fa_x,x)>k$ 应该 `continue` 而不是 `break` ,因为这距离并是实际距离,在点分树上没有单调性。 ```cpp #include <bits/stdc++.h> using namespace std ; #define il inline #define to(k) E[k].to #define next(k) E[k].next #define low(x) (x & (-x)) const int N = 300010 ; void debug(int *tp, int minn, int maxn, char c){ for (int i = minn ; i <= maxn ; ++ i) cout << tp[i] << " " ; cout << c ; } int res ; int ans ; int lans ; int n, m ; int f[N] ; int d[N] ; int vis[N] ; int dep[N] ; int mx_dep ; struct Edge{ int to ; int next ; }E[N * 2] ; int cnt ; int base[N] ; int head[N] ; unordered_map<int, int> Id[N] ; vector <int> sub[N] ; vector <int> buc[N] ; il void add(int a, int b){ to(++ cnt) = b ; next(cnt) = head[a] ; head[a] = cnt ; } namespace findCG{ int grt ; int num ; int g[N] ; int size[N] ; il void chk(int &a, int b){ a = b <= a ? a : b ; } il void reset(){ g[grt = 0] = 19690126 ; } void dfs(int x, int fa){ size[x] = 1 ; g[x] = 0 ; for (int k = head[x] ; k ; k = next(k)){ if (to(k) != fa && !vis[to(k)]){ dfs(to(k), x) ; size[x] += size[to(k)] ; g[x] = max(g[x], size[to(k)]) ; } } chk(g[x], num - size[x]) ; if (g[x] < g[grt]) grt = x ; } } using namespace findCG ; il void init(int root, int x){ for (int i = 0 ; i <= x ; ++ i) buc[root].push_back(0) ; } il void add(int root, int x, int p){ int t = buc[root].size() ; for ( ; x < t ; x += low(x)) buc[root][x] += p ; } il int ask(int root, int x){ int ret = 0 ; if (x >= buc[root].size()) x = (int)buc[root].size() - 1 ; for ( ; x ; x -= low(x)) ret += buc[root][x] ; return ret ; } il void init2(int root, int x){ for (int i = 0 ; i <= x ; ++ i) sub[root].push_back(0) ; } il void add2(int root, int x, int p){ int t = sub[root].size() ; for ( ; x < t ; x += low(x)) sub[root][x] += p ; } il int ask2(int root, int x){ int ret = 0 ; if (x >= sub[root].size()) x = (int)sub[root].size() - 1 ; for ( ; x ; x -= low(x)) ret += sub[root][x] ; return ret ; } void calc(int x, int fa, int root){ size[x] = 1 ; dep[x] = dep[fa] + 1 ; Id[root][x] = dep[x] ; for (int i = head[x] ; i ; i = next(i)) if (!vis[to(i)] && to(i) != fa) calc(to(i), x, root), size[x] += size[to(i)] ; mx_dep = max(dep[x], mx_dep) ; } void calc2(int x, int fa, int root, int frt){ add(root, dep[x], base[x]) ; if (frt) add2(root, Id[frt][x], base[x]) ; for (int i = head[x] ; i ; i = next(i)) if (!vis[to(i)] && to(i) != fa) calc2(to(i), x, root, frt) ; } void find_tree(int x, int fa, int h){ int mx ; vis[x] = 1 ; mx_dep = 0 ; calc(x, 0, x), init(x, mx_dep) ; init2(x, h) ; mx = mx_dep ; calc2(x, 0, x, fa) ; for (int k = head[x] ; k ; k = next(k)){ if (vis[to(k)]) continue ; num = size[to(k)] ; reset() ; dfs(to(k), x) ; f[grt] = x ; find_tree(grt, x, mx) ; } } il int qr(){ int r = 0 ; char c = getchar() ; while (!isdigit(c)) c = getchar() ; while (isdigit(c)) r = r * 10 + c - 48, c = getchar() ; return r ; } int main(){ int a, b, c ; cin >> n >> m ; for (int i = 1 ; i <= n ; ++ i) base[i] = qr() ; for (int i = 1 ; i < n ; ++ i) a = qr(), b = qr(), add(a, b), add(b, a) ; reset() ; num = n ; dfs(1, 0) ; find_tree(grt, 0, 0) ; while (m --){ a = qr() ; b = qr() ^ lans ; c = qr() ^ lans ; if (!a){ int fb = f[b] ; int ob, lb = b, df ; ans += ask(lb, c + 1) ; while (fb){ df = Id[fb][b] - 1 ; if (c - df < 0){ lb = fb, fb = f[fb] ; continue ; } ans += ask(fb, c - df + 1) ; ans -= ask2(lb, c - df + 1) ; lb = fb, fb = f[fb] ; } printf("%d", (lans = ans)) ; ans = 0 ; putchar('\n') ; } else { int ob = b ; add(b, 1, -base[b] + c) ; while (f[b]){ int df = Id[f[b]][ob] ; add(f[b], df, -base[ob] + c) ; add2(b, df, -base[ob] + c) ; b = f[b] ; } base[ob] = c ; } } return 0 ; } ``` ~~為什么我這麼慢啊~~