拋磚引玉.jpeg
_______
大概就是如果没有修改操作的话,就是比较裸的点分树。于是先考虑没有修改操作的情况。
考虑怎么维护这个东西,自然是希望对每个点都记录一个桶,但这样显然由于每个点的深度不可控,最终需要的空间代价是 $O(n^2)$ 的。于是考虑怎么调整树的高度使得最终总的空间複杂度可以接受,那自然就会想到点分治。注意到点分治时,每个点在分治过程中,『逻辑树高』都只有 $\log n$ 。这大概就是为什么用点分树的原因。
所以就是建出点分树来,每个点维护一个 `vector` 作为桶,维护点分树上子树内到当前点距离为 $k$ 的点权和。这样对于询问,每次只需要跳点分树,然后对于每个 $fa$ 统计 $k-dis(fa,x)$ 的点对的数量就好了。但是还有一个问题,就是对于以当前 $fa$ 为根的那些子树,在算下一个 $fa$ 的时候会被算重。于是就要再维护一个桶,表示 $x$ 子树内的点,到点分树上 $x$ 的父亲的距离为 $k$ 的点权和。由于边权都为 $1$ ,这个操作就会很方便。
考虑如果带修改,那无非就是把桶换成树状数组即可。这样複杂度就会是 $O(m\log ^2 n)$ 的了。可能我写的比较丑?预处理是常数不小的 $O(n\log ^2 n)$ ,似乎比其他人都慢诶…
以下是第一次写点分树相关题的人可能会遇到的 bug:
1、最开始的时候维护的是 **点分树** 上距离为 $k$ 的点的点权和。
2、然后改了改,但是查询的时候没有维护两个 BIT,只维护了一个,然后减去的是查询 $x$ 的点分树子树内到 $x$ 距离 $\leq k-2\times dis(fa_x,x)$ 的点权和。看上去有点东西,但问题在于到 $x$ 距离和到 $fa_x$ 距离没有本质上的关係…比如可以在树的对侧。
3、最后还是写了两个 BIT,但是调了很久,原因是向上跳遇到 $dis(fa_x,x)>k$ 应该 `continue` 而不是 `break` ,因为这距离并是实际距离,在点分树上没有单调性。
```cpp
#include <bits/stdc++.h>
using namespace std ;
#define il inline
#define to(k) E[k].to
#define next(k) E[k].next
#define low(x) (x & (-x))
const int N = 300010 ;
void debug(int *tp, int minn, int maxn, char c){
for (int i = minn ; i <= maxn ; ++ i)
cout << tp[i] << " " ; cout << c ;
}
int res ;
int ans ;
int lans ;
int n, m ;
int f[N] ;
int d[N] ;
int vis[N] ;
int dep[N] ;
int mx_dep ;
struct Edge{
int to ;
int next ;
}E[N * 2] ;
int cnt ;
int base[N] ;
int head[N] ;
unordered_map<int, int> Id[N] ;
vector <int> sub[N] ;
vector <int> buc[N] ;
il void add(int a, int b){
to(++ cnt) = b ;
next(cnt) = head[a] ;
head[a] = cnt ;
}
namespace findCG{
int grt ;
int num ;
int g[N] ;
int size[N] ;
il void chk(int &a, int b){
a = b <= a ? a : b ;
}
il void reset(){
g[grt = 0] = 19690126 ;
}
void dfs(int x, int fa){
size[x] = 1 ; g[x] = 0 ;
for (int k = head[x] ; k ; k = next(k)){
if (to(k) != fa && !vis[to(k)]){
dfs(to(k), x) ;
size[x] += size[to(k)] ;
g[x] = max(g[x], size[to(k)]) ;
}
}
chk(g[x], num - size[x]) ;
if (g[x] < g[grt]) grt = x ;
}
}
using namespace findCG ;
il void init(int root, int x){
for (int i = 0 ; i <= x ; ++ i)
buc[root].push_back(0) ;
}
il void add(int root, int x, int p){
int t = buc[root].size() ;
for ( ; x < t ; x += low(x)) buc[root][x] += p ;
}
il int ask(int root, int x){
int ret = 0 ;
if (x >= buc[root].size())
x = (int)buc[root].size() - 1 ;
for ( ; x ; x -= low(x)) ret += buc[root][x] ;
return ret ;
}
il void init2(int root, int x){
for (int i = 0 ; i <= x ; ++ i)
sub[root].push_back(0) ;
}
il void add2(int root, int x, int p){
int t = sub[root].size() ;
for ( ; x < t ; x += low(x)) sub[root][x] += p ;
}
il int ask2(int root, int x){
int ret = 0 ;
if (x >= sub[root].size())
x = (int)sub[root].size() - 1 ;
for ( ; x ; x -= low(x)) ret += sub[root][x] ;
return ret ;
}
void calc(int x, int fa, int root){
size[x] = 1 ;
dep[x] = dep[fa] + 1 ;
Id[root][x] = dep[x] ;
for (int i = head[x] ; i ; i = next(i))
if (!vis[to(i)] && to(i) != fa)
calc(to(i), x, root), size[x] += size[to(i)] ;
mx_dep = max(dep[x], mx_dep) ;
}
void calc2(int x, int fa, int root, int frt){
add(root, dep[x], base[x]) ;
if (frt) add2(root, Id[frt][x], base[x]) ;
for (int i = head[x] ; i ; i = next(i))
if (!vis[to(i)] && to(i) != fa) calc2(to(i), x, root, frt) ;
}
void find_tree(int x, int fa, int h){
int mx ; vis[x] = 1 ; mx_dep = 0 ;
calc(x, 0, x), init(x, mx_dep) ;
init2(x, h) ; mx = mx_dep ;
calc2(x, 0, x, fa) ;
for (int k = head[x] ; k ; k = next(k)){
if (vis[to(k)]) continue ;
num = size[to(k)] ; reset() ;
dfs(to(k), x) ; f[grt] = x ;
find_tree(grt, x, mx) ;
}
}
il int qr(){
int r = 0 ;
char c = getchar() ;
while (!isdigit(c)) c = getchar() ;
while (isdigit(c))
r = r * 10 + c - 48, c = getchar() ;
return r ;
}
int main(){
int a, b, c ; cin >> n >> m ;
for (int i = 1 ; i <= n ; ++ i) base[i] = qr() ;
for (int i = 1 ; i < n ; ++ i)
a = qr(), b = qr(), add(a, b), add(b, a) ;
reset() ; num = n ; dfs(1, 0) ; find_tree(grt, 0, 0) ;
while (m --){
a = qr() ;
b = qr() ^ lans ;
c = qr() ^ lans ;
if (!a){
int fb = f[b] ;
int ob, lb = b, df ;
ans += ask(lb, c + 1) ;
while (fb){
df = Id[fb][b] - 1 ;
if (c - df < 0){
lb = fb, fb = f[fb] ;
continue ;
}
ans += ask(fb, c - df + 1) ;
ans -= ask2(lb, c - df + 1) ;
lb = fb, fb = f[fb] ;
}
printf("%d", (lans = ans)) ;
ans = 0 ; putchar('\n') ;
}
else {
int ob = b ;
add(b, 1, -base[b] + c) ;
while (f[b]){
int df = Id[f[b]][ob] ;
add(f[b], df, -base[ob] + c) ;
add2(b, df, -base[ob] + c) ; b = f[b] ;
}
base[ob] = c ;
}
}
return 0 ;
}
```
~~為什么我這麼慢啊~~