yltx's blog

yltx's blog

Κοιτάζοντας πάνω στον έναστρο ουρανό, κάτω στη γη.

浅谈树链剖分

posted on 2019-09-01 17:20:23 | under 题解 |

更好的阅读体验?https://yltx.cf/2019/09/01/%E6%B5%85%E8%B0%88%E6%A0%91%E9%93%BE%E5%89%96%E5%88%86/#more

引子

在OI中,有时候我们会需要处理一些树上的链的问题

比方说,给定一棵 $n$ 个点的树, $m$ 个操作,每次查询 $x$ 和 $y$ 之间的链上的和

不要在意到底是什么操作,这真的只是个引子

考虑做法。

  1. level 1

    $ n,m \leq 100 $

    那么我们直接暴力求即可。复杂度 $O(nm)$

  2. level 2

    $ n\leq 1000 , m \leq 100000$

    很明显不能直接暴力了。

    询问过多,但是并没有修改操作,所以可以考虑 $O(n^2)$ 把 $x$ 和 $y$ 之间的和预处理出来,存起来,直接访问。

  3. level 3

    $ n\leq10^5,m\leq10^5$ ,并且要求支持修改操作

    这才是我们今天要讨论的问题

    为了支持 $O(logn)$ 的修改, $O(logn)$ 的查询,我们发展了这种叫树链剖分的东西。

树链剖分

前置芝士

  • DFS

  • 线段树

  • 链式前向星

定义

  • 重儿子:一个点有多个儿子,定义其儿子中子树大小最大的儿子为重儿子。

  • 轻儿子:一个点不是重儿子的儿子都是轻儿子。

  • 重边:一个点与其重儿子之间的边。

  • 轻边:一个点与其轻儿子之间的边。

  • 重链:完全由重边连成的链。

  • 重链的顶端:一条重链上深度最小(最靠近根)的点。

  • 特别的,我们为了代码的舒适性人为定义一个轻儿子为一条长度为1的重链。

树链剖分能做什么?

  • 解决一条链上的信息查询/修改问题

  • 其他链上线段树能维护的东西

实现

存树

这个实际上不难,所有数字直接先丢进线段树的数组,边的话直接读入的时候链式前向星存下来就好。

考虑到是双向边,要加两次。

inline void add(int u,int v){//链式前向星加边
    to[++bian]=v;//记下现在第bian条边所指向的节点
    nxt[bian]=beg[u];//指针指向u的链表原来的表头
    beg[u]=bian;//更新表头
}
scanf ("%lld%lld%lld%lld",&n,&m,&r,&mod);//读入
fa[r]=0,dep[r]=1;//r是根节点,根节点的父亲是0,深度的1
for (int i=1;i<=n;i++)scanf ("%lld",&tree.a[i]);
for (int i=1;i<n;i++){
    ll a,b;
    scanf("%lld%lld",&a,&b);
    add(a,b),add(b,a);
}

两遍DFS

第一遍DFS

这边DFS要处理出以下的信息:

  • 一个点的深度 $d$

  • 一个点的重儿子 $s$

  • 一个点的父亲 $f$

  • 一个点的子树大小(含自己) $siz$

代码:

inline void dfs1(int x){
    siz[x]=1;//初始大小是1(只有自己)
    for (int i=beg[x];i;i=nxt[i]){//访问x的所有出边
        if (to[i]==fa[x])continue;//到父亲的边,不考虑
        fa[to[i]]=x;//访问到的节点的父亲是x
        dep[to[i]]=dep[x]+1;//深度是x+1
        dfs1(to[i]);//向访问到的节点DFS
        siz[x]+=siz[to[i]];//加上儿子的子树大小
        if (siz[to[i]]>siz[son[x]])son[x]=to[i];//更新重儿子的信息
    }
}

第二遍DFS

在第一遍DFS的基础上,我们现在知道了每个点的子树大小(包括自己),重儿子等信息。

接下来就是树链剖分的核心了:

把一棵树按照轻重边剖分成若干条链,剖分的过程就是第二遍DFS

至于剖分的原因,后面在证明复杂度的时候会说

具体实现:

我们对每个点重标号,使得一条重链上的点的标号是连续的,然后对重标号后的点建线段树

第二遍DFS需要处理这些内容:

  • 记录下每个点的新标号

  • 把这个点的值赋到新标号上(之后建线段树要用)

  • 记录下每个点所在的重链的顶端

代码:

inline void dfs2(int x,int y){
    top[x]=y;id[x]=++tot;num[tot]=x;//记录下x所在的重链的顶端y,同时为新标号赋值
    if (son[x])dfs2(son[x],y);//优先DFSx的重儿子,这样保证一条重链上的点的标号是连续的
    for (int i=beg[x];i;i=nxt[i]){
        if (to[i]==fa[x]||to[i]==son[x])continue;
        dfs2(to[i],to[i]);//DFS剩下的轻儿子
    }
}

处理

敲黑板:重点来了!也许你不清楚前文提到的两边DFS的意义,这里会有解释

我们进行了第二遍DFS之后,得到了一下的结果:

  • 由于我们DFS时优先考虑重儿子,这样每条重链上的点的新标号是连续的

  • 由于是DFS,每棵子树的新标号是连续的

链的查询/修改

有了上文,本来链上的问题变成了统计若干条重链加轻边的问题。

查询和修改的思路实际上很相似,都是两个点分别向上跳,直到在同一条重链上为止,最后直接统计一次重链上两个点之间的和

既然重链上的新标号的连续的,那么我们就可以用线段树维护每条重链的和,这样重链的查询就是 $O(logn)$ 的了。

至于轻边,我们可以直接暴力累加,但由于我们人为定义了每个轻儿子是一条长度为1的重链,所以实际上写代码的时候可以和重链的查询合并。

修改的代码:

void chge(int x,int y,int k){
    k%=mod;
    while (top[x]!=top[y]){//使用类似倍增求LCA的思想,每次找深度大的点往上跳
        if (dep[top[x]]<dep[top[y]])swap(x,y);//保证x的深度大
        tree.change(1,id[top[x]],id[x],1,n,k);//直接统计x所在的重链的和,这条重链的起点是x所在的重链的顶端的新标号,终点是x的新标号
        x=fa[top[x]];//直接跳到x所在重链顶端的父亲,保证不重复统计
    }//把两个点跳到同一条重链上
    if (dep[x]<dep[y])swap(x,y);
    tree.change(1,id[y],id[x],1,n,k);//最后处理重链上两个点之间的部分
}

查询的代码:

int query(int x,int y){
    int ans=0;
    while (top[x]!=top[y]){
        if (dep[top[x]]<dep[top[y]])swap(x,y);//同样是选深度大的向上跳
        ans=(ans+tree.ask(1,id[top[x]],id[x],1,n))%mod;
        x=fa[top[x]];
    }
    if (dep[x]<dep[y])swap(x,y);
    ans=(ans+tree.ask(1,id[y],id[x],1,n))%mod;//也是最后处理同一条重链上x和y之间的部分
    return ans;
}

实际上很像,只是把线段树的修改换成了查询而已。

子树的修改/查询

既然有了一棵子树的新标号的连续的保证,那么原来的一棵子树实际上对应着线段树中的一段连续区间。

那么直接线段树区间修改/查询就好。而且由于是连续区间,那么 $x$ 的子树的起点必然是 $x$ 的新标号,终点必然是 $x$ 的新标号 $+x$ 的子树大小再 $-1$ 。

那么代码就很简单了:

修改:

tree.change(1,id[x],id[x]+siz[x]-1,1,n,k)

查询:

tree.ask(1,id[x],id[x]+siz[x]-1,1,n)

复杂度简单证明

当树为二叉树的时候,深度最大。

一棵树总共有 $n$ 个点,由于一个点 $x$ 的重儿子起码占了 $x$ 的子树大小的一半,那么每次递归下去节点个数都减半,很明显 $logn$ 次就能结束。

所以重链的数量 $\leq logn$ 。

又因为每两条重链之间必然由轻边分割(不然就连成一条重链了),所以轻边的数量同样 $\leq logn$ 。

于是树链剖分做到了对于一条链上的查询/修改, $O(log^2n)$ 的复杂度。

很是优秀了。

完整代码

实际上讲到这里应该都会写了吧(雾,但为了讲清楚一些细节还是给一下吧

#include <bits/stdc++.h>
using namespace std;
const int N=100005;
#define ll long long
int dep[N],siz[N],fa[N],z[N],to[N<<1],beg[N<<1],nxt[N<<1],top[N<<2],bian,son[N<<1],id[N<<1],tot,n,m,r,mod,num[N];
struct Tree{
    ll ans[N<<2],tag[N<<2],a[N];
    inline ll lson(ll p){return p<<1;}
    inline ll rson(ll p){return (p<<1)|1;}
    inline void push_up(ll p){ans[p]=(ans[lson(p)]+ans[rson(p)])%mod;}
    inline void build(ll p,ll l,ll r){
        if (l==r){ans[p]=a[num[l]];return ;}
        ll mid=(l+r)>>1;
        build(lson(p),l,mid);
        build(rson(p),mid+1,r);
        push_up(p);
        tag[p]=0;
    }
    inline void lazy_tag(ll p,ll l,ll r,ll k){ans[p]=(ans[p]+(r-l+1)*k)%mod,tag[p]=(tag[p]+k)%mod;}
    inline void push_down(ll p,ll l,ll r){
        ll mid=(l+r)>>1;
        lazy_tag(lson(p),l,mid,tag[p]);
        lazy_tag(rson(p),mid+1,r,tag[p]);
        tag[p]=0;
    }
    inline void change(ll p,ll nl,ll nr,ll l,ll r,ll k){//nl,nr->changing l,changing r;l,r->visiting l,visiting r
        if (nl<=l&&nr>=r){ans[p]=(ans[p]+(r-l+1)*k)%mod,tag[p]=(tag[p]+k)%mod;return ;}
        ll mid=(l+r)>>1;
        push_down(p,l,r);
        if (nl<=mid)change(lson(p),nl,nr,l,mid,k);
        if (nr>mid)change(rson(p),nl,nr,mid+1,r,k);
        push_up(p);
    }
    inline ll ask(ll p,ll nl,ll nr,ll l,ll r){
        if (nl<=l&&nr>=r)return ans[p];
        ll mid=(l+r)>>1,res=0;
        push_down(p,l,r);
        if (nl<=mid)res=(res+ask(lson(p),nl,nr,l,mid))%mod;
        if (nr>mid)res=(res+ask(rson(p),nl,nr,mid+1,r))%mod;
        return res;
    }
}tree;//之前封装好的线段树
inline void add(int u,int v){//链式前向星加边
    to[++bian]=v;
    nxt[bian]=beg[u];
    beg[u]=bian;
}
inline void dfs1(int x){//第一遍dfs,之前已经详细讲过
    siz[x]=1;
    for (int i=beg[x];i;i=nxt[i]){
        if (to[i]==fa[x])continue;
        fa[to[i]]=x;
        dep[to[i]]=dep[x]+1;
        dfs1(to[i]);
        siz[x]+=siz[to[i]];
        if (siz[to[i]]>siz[son[x]])son[x]=to[i];
    }
}
inline void dfs2(int x,int y){//同上
    top[x]=y;id[x]=++tot;num[tot]=x;
    if (son[x])dfs2(son[x],y);
    for (int i=beg[x];i;i=nxt[i]){
        if (to[i]==fa[x]||to[i]==son[x])continue;
        dfs2(to[i],to[i]);
    }
}
int query(int x,int y){//链的查询,也讲过了
    int ans=0;
    while (top[x]!=top[y]){
        if (dep[top[x]]<dep[top[y]])swap(x,y);
        ans=(ans+tree.ask(1,id[top[x]],id[x],1,n))%mod;
        x=fa[top[x]];
    }
    if (dep[x]<dep[y])swap(x,y);
    ans=(ans+tree.ask(1,id[y],id[x],1,n))%mod;
    return ans;
}
void chge(int x,int y,int k){//链的修改
    k%=mod;
    while (top[x]!=top[y]){
        if (dep[top[x]]<dep[top[y]])swap(x,y);
        tree.change(1,id[top[x]],id[x],1,n,k);
        x=fa[top[x]];
    }
    if (dep[x]<dep[y])swap(x,y);
    tree.change(1,id[y],id[x],1,n,k);
}
int main(){
    scanf ("%lld%lld%lld%lld",&n,&m,&r,&mod);fa[r]=0,dep[r]=1;
    for (int i=1;i<=n;i++)scanf ("%lld",&tree.a[i]);
    for (int i=1;i<n;i++){
        ll a,b;
        scanf("%lld%lld",&a,&b);
        add(a,b),add(b,a);
    }
    dfs1(r);//从根节点开始
    dfs2(r,r);//同上,根节点必然是根所在的重链的顶端
    tree.build(1,1,n);//对于我们重标号之后的数组,建线段树
    while(m--){
        ll flag,x,y,k;
        scanf ("%lld",&flag);
        switch (flag){
            case 1:scanf ("%lld%lld%lld",&x,&y,&k);chge(x,y,k);break;//链上修改
            case 2:scanf ("%lld%lld",&x,&y);printf("%lld\n",query(x,y));break;//链上查询
            case 3:scanf ("%lld%lld",&x,&k);tree.change(1,id[x],id[x]+siz[x]-1,1,n,k);break;//子树修改
            case 4:scanf ("%lld",&x);printf("%lld\n",tree.ask(1,id[x],id[x]+siz[x]-1,1,n));//子树查询
        }
    }
}