树链剖分良心讲解

· · 题解

树上路径,分清边孰轻孰重。

——《我的一个OIer朋友》

树链剖分,简称树剖。是一种巧妙的算法,可以把树划分成许多链,简便地实现树上的修改与查询操作。

(这里说的 树链剖分 指的是 轻重链剖分)
↑ 如果你初学树剖可以无视这句话。

在接下来的讲解中,您将不可避免地接触以下名词:

  1. 重儿子:对于一个非叶子节点 u,它会有许多儿子。其中有一个儿子的孩子最多(也就是子树最大),那么这个儿子就叫 u 的重儿子;
  2. 轻儿子:不是重儿子,就是轻儿子;
  3. 重链:一条全部以重儿子为节点的路径,除了顶部为轻儿子。

此处名词解释相对其它很多 blog 有所简化,不过不影响阅读。
纯文字的描述好像不太方便,那我们拿张图来:

图中,重儿子、重链都用红色标出来了。
我们可以看到,对于节点 1 来说,它的儿子 2 号节点子树最大,因此 21 的重儿子。
对于 2 号点,它的儿子 6 号节点子树最大,因此 62 的重儿子。
对于其它节点同理。如果想加深印象,可以试着手动模拟一下。

接着就是重头戏了:如何代码实现树剖

树剖的实际上就是两遍预处理,第一步要算出以下数据:
对于任意节点 u深度子树大小重儿子编号父节点编号;分别记为:\text{depth , size , son , fa}
这点用一个简单的 DFS 就可以实现,时间复杂度 \Theta(n)

void dfs1(int u,int f){ //u为当前节点,f为父节点
    fa[u] = f;
    size[u] = 1; //子树大小要算上子树的根节点,也就是u
    depth[u] = depth[f]+1; //比父亲深度大1
    int v,t = -1,l = adj[u].size(); //此处使用vector存图
    for(int i=0;i<l;++i){ //遍历连接u的点v
        v = adj[u][i];
        if(v==f) continue;
        dfs1(v,u);
        size[u] += size[v];
        if(size[v]>t){ 
            //如果这个子树大小比已找到的还大,那就更新已找到的,同时更新u的重儿子为v
            t = size[v];
            son[u] = v;
        }
    }
}

第二遍预处理,则需要计算:

对于任意节点 u所在重链顶点第几个被遍历,分别记为 \text{top , id}
时间复杂度也是线性的。

void dfs2(int u,int f){ //这里f就不是指父亲了,是u所在重链的顶端节点
    top[u] = f;
    id[u] = ++cnt;
    if(w[u]!=0)
        add(id[u],id[u],w[u]); //树状数组维护区间和,w[u]为u的初始权值
    if(son[u]==0) return; //重儿子编号为0意味着没有儿子,返回
    dfs2(son[u],f); //先从重儿子开始dfs,这样可以使一条重链上的节点id连续,便于区间操作
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==son[u]||v==fa[u]) continue;
        dfs2(v,v); //由于是轻儿子,所以其所在重链顶端节点自然是自己
    }
}

然后你就学会了树剖
并不是,你还需要实现修改和查询。
在考虑这个问题之前,我们先来看一下刚才预处理的成果吧:
现在把前面贴的图改了一下,逗号后面的数表示这个节点的 \text{id}

来看一下各点的 \text{id},你可以发现两个性质:

  1. 一条重链上的点 \text{id} 连续;
  2. 一棵子树上的点 \text{id} 也连续。

接下来就要利用这个性质搞定那四个操作啦!
来看第一个操作:

将树从 xy 结点最短路径上所有节点的值都加上 z

我们可以回想一下当初倍增求 LCA 的时候,是怎么办的?
就是两个节点,深度大的往上跳,一直跳到父亲相同。
这里的思想也和求 LCA 很相似。

由于每条重链上的点 \text{id} 连续,所以我们可以直接调用

add(id[top[u]],id[u],k);

就可以把 u 到其重链顶点的节点值都增加 k
(这里 add 就是区间加操作)

加完了之后,我们再执行

u = fa[top[u]];

这样就 u 跳到其重链顶上去了。两个节点按这样的方式不断向上跳,直到一条重链上。
于是操作 1 就完成了,完整代码如下:

void addPath(int u,int v,int k){
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])  
            swap(u,v); //深度大的点先跳,保证能跳到一条重链上
        add(id[top[u]],id[u],k);    
        u = fa[top[u]];
    }
    if(depth[u]>depth[v]) swap(u,v);
    add(id[u],id[v],k); //在一条重链上,直接加
}

操作 2 只是大同小异,区间修改变成了区间查询而已。

int queryPath(int u,int v){
    int res = 0;
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        res += query(id[top[u]],id[u]);
        u = fa[top[u]];
    }
    if(depth[u]>depth[v]) swap(u,v);
    res += query(id[u],id[v]);
    return res;
}

接下来看操作3:

将以 x 为根节点的子树内所有节点值都加上 z

关于子树的操作其实非常简单。
题目的修改和查询都是对整个子树的。又由于子树的节点 \text{id} 是连续的,所以该怎么办已经很清楚了吧。

修改和查询依然很相似,直接放在一起贴出来:

int querySon(int u){
    return query(id[u],id[u]+size[u]-1);
    //id[u]到id[u]+size[u]-1刚好涵盖了u的所有子节点id,下同
}

void addSon(int u,int k){
    add(id[u],id[u]+size[u]-1,k);
}

然后这题代码就出来了,但是题目中多了个取模,注意要加上。

总时间复杂度为 \Theta(n\log^2 n)

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#define N 100003
#define inf 0x3f3f3f3f
#define ll long long
using namespace std;

vector<int> adj[N];
int n,m,r,p,cnt;
int son[N],depth[N],fa[N],size[N];
int id[N],top[N],w[N];
ll c1[N],c2[N];

//以下为树状数组
inline int lowbit(int x){
    return x&(-x);
}

inline void add(int l,int r,int x){
    x %= p;
    int ad1 = (ll)(l-1)*x%p;
    int ad2 = (ll)r*x%p;
    for(int t=l;t<=n;t+=lowbit(t)){
        c1[t] = (c1[t]+x)%p;
        c2[t] = (c2[t]+ad1)%p;
    }
    for(int t=r+1;t<=n;t+=lowbit(t)){
        c1[t] = (c1[t]-x)%p;
        c1[t] = (c1[t]+p)%p;
        c2[t] = (c2[t]-ad2)%p;
        c2[t] = (c2[t]+p)%p;
    }
}

inline int qwq(int i){ //qwq
    int res = 0;
    for(int t=i;t>0;t-=lowbit(t)){
        res = (res+(ll)i*c1[t]%p)%p;
        res = (res-c2[t])%p;
        res = (res+p)%p;
    }
    return res;
}

inline int query(int l,int r){
    int res = (qwq(r)-qwq(l-1))%p;
    return (res+p)%p;
}
//以上树状数组

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

void dfs1(int u,int f){
    fa[u] = f;
    size[u] = 1;
    depth[u] = depth[f]+1;
    int v,t = -1,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==f) continue;
        dfs1(v,u);
        size[u] += size[v];
        if(size[v]>t){
            t = size[v];
            son[u] = v;
        }
    }
}

void dfs2(int u,int f){
    top[u] = f;
    id[u] = ++cnt;
    if(w[u]!=0)
        add(id[u],id[u],w[u]);
    if(son[u]==0) return;
    dfs2(son[u],f);
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==son[u]||v==fa[u]) continue;
        dfs2(v,v);
    }
}

int queryPath(int u,int v){
    int res = 0;
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        res = (res+query(id[top[u]],id[u]))%p;
        u = fa[top[u]];
    }
    if(depth[u]>depth[v]) swap(u,v);
    res = (res+query(id[u],id[v]))%p;
    return res;
}

void addPath(int u,int v,int k){
    k %= p;
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        add(id[top[u]],id[u],k);    
        u = fa[top[u]];
    }
    if(depth[u]>depth[v]) swap(u,v);
    add(id[u],id[v],k);
}

int querySon(int u){
    return query(id[u],id[u]+size[u]-1);
}

void addSon(int u,int k){
    k %= p;
    add(id[u],id[u]+size[u]-1,k);
}

int main(){
    int u,v;
    read(n),read(m),read(r),read(p);
    for(int i=1;i<=n;++i)
        read(w[i]);
    for(int i=1;i<n;++i){
        read(u),read(v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }    
    dfs1(r,0);
    dfs2(r,r);
    int ans,op,x,y,z;
    while(m--){
        read(op),read(x);
        if(op==1){
            read(y),read(z);
            addPath(x,y,z);
            continue;
        }
        if(op==2){
            read(y);
            ans = queryPath(x,y);
            print(ans);
            putchar('\n');
            continue;
        }
        if(op==3){
            read(z);
            addSon(x,z);
            continue;
        }
        ans = querySon(x);
        print(ans);
        putchar('\n');
    }
    return 0;
}

此外,树剖还可以用来求LCA,代码比倍增法短很多:

inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        u = fa[top[u]];    
    }
    if(depth[u]<depth[v]) return u;
    return v;
}

是不是超级简单