如何写出简单又强势的线段树

· · 算法·理论

许多算法的本质是统计。线段树用于统计,是沟通原数组与前缀和的桥梁。

——《统计的力量》清华大学-张昆玮

前言&省流

想必递归实现的线段树写法大家都知道吧,如果还不知道的话看这里。现在我结合自己的理解介绍一种靠 递推 而非递归写法的线段树(即 zkw 线段树),努力结合图像用更加通俗的语言来讲这一种方法,造福更多像我这样的蒟蒻。

下面用 【模板】线段树 2 展示这一种写法的优越性。

这是我之前写的递归版线段树评测记录:

这是我现在写的 zkw 线段树评测记录:

可以看到无论是码量、内存用量还是运行时长都具有相当的优越性 可谓是多快好省

那么这么简单又强势的西格蒙特树该怎么写呢?

实现

一般我们都是这么有所建树的:

即令根节点为 1,然后据此递归往下建树。

但是我们也可以先假定最底层的 n 个节点的编号,这样更加方便递推。

由于在满二叉树下最底层有 2^h 个节点,其余层共 2^0 + 2^1 + 2^2 + \dots + 2^{h - 1},即 2^h - 1,所以我们只需把 i 处的数值放在 n - 1 + i 处前面就有足够的空间建树。然后节点 i 的父亲节点就是节点 i >> 1

效果如下:

我们就会发现它长相极其混沌邪恶……

那么该怎么办呢?

注意到节点二储存的区间并不完整,且包含的区间分布在左右两端,所以只要不是查询区间 [1,n] 就用不着。

所以我们大胆地抛弃节点一和节点二,效果就变成了这样:

节点 3,4,5 并没有父亲节点,所以线段树变成了一堆散块线段森林

建树代码如下:

void pushup(ll x){
    if(!tag[x])    sgt[x] = (sgt[x << 1] + sgt[x << 1 | 1]) % mod;
}//在没有懒标记的情况下将值从下往上传(这个函数还能重复利用)
void build(ll l,ll r){//这个也能重复利用
    l += n - 1,r += n - 1;//找到l,r对应的节点
    while (l > 1){//建树
        l >>= 1,r >>= 1;//往上找父亲节点的范围
        fc(r,l,i)    pushup(i);//然后对于区间内的父亲节点更新值
    }
}
build(1,n);

虽然节点一和节点二说是被抛弃了,但是只要不用就行了,不用一些乱七八糟的判断。

对于懒标记的操作比较常规,和更新值的操作整合到一起了。

void retag(ll x,ll add,ll mul,ll len){//更新x节点的信息,同时传入它要加的数、要乘的数以及它所含区间长度
    sgt[x] = (sgt[x] * mul + add * len) % mod;//不管怎么样先更新本节点的数值
    if(x < n){//如果这个节点存的信息所含区间长度不为1
        tag[x] = true;//打上标记,准备往下传
        mtag[x] = mtag[x] * mul % mod;//更新乘标记
        atag[x] = (atag[x] * mul + add) % mod;//更新加标记
    }
}

然后是考虑如何把标记下放。对于 [l,r] 的区间操作,我们只需要对于 lr 往上溯源,找到它们所有的祖先,下传它们。

h = 32 - __builtin_clz(n);//这样建树理论上的最大树高
void pushdown(ll x){//对于x往上溯源,找到它所有的祖先,下传它们。
    ll s(h),len(1 << (h - 1));//先直接找到最高处(那个节点甚至可能没意义)
    x += n - 1;//找到x对应的那一个节点
    while (s > 0){//遍历x的祖宗直到x的父亲节点
        ll y(x >> s);//找到那个祖宗对应的节点
        if(tag[y]){//如果那一个祖宗有标记
            retag(y << 1,atag[y],mtag[y],len);//更新它左儿子
            retag(y << 1 | 1,atag[y],mtag[y],len);//更新它右儿子
            tag[y] = 0,atag[y] = 0,mtag[y] = 1;//还原标记
        }
        -- s,len >>= 1;//继续往下遍历x的祖宗
    }
}

到这里前期各种函数终于都准备好了,接下来是区间修改的函数。怎么实现呢?

void change(ll l,ll r,ll add,ll mul){//区间修改,整合了加操作和乘操作
    ll cl(l),cr(r),len(1);//保存l,r的值以便上传,同时保存当前l,r所能包含的区间长
    pushdown(l),pushdown(r);//先下传
    l += n - 1,r += n - 1;//找到l,r对应的节点
    while (l <= r){
        if(l & 1)   retag(l ++,add,mul,len);
        if(~ r & 1)   retag(r --,add,mul,len);
        l >>= 1,r >>= 1,len <<= 1;
        //就是把散块做了,然后一直往上层递推,在l>r时停止。
        //具体实现:l%2=1时说明在一个散块中,r%2=0时说明在一个散块中。散块大小显然为1。
    }
    build(cl,cl),build(cr,cr);//循环利用build函数来对l,r的祖宗节点更新
}

是不是有一点懵逼?那么我们拿最上面那个 n = 9 的线段森林树修改区间 [2,7] 演示一下。

现在你们大概知道散块是什么意思了吧,总之这样我们就以一种巧妙的方式执行了操作。

至于为什么

//就是把散块做了,然后一直往上层递推,在l>r时停止。
//具体实现:l%2=1时说明在一个散块中,r%2=0时说明在一个散块中。散块大小显然为1。

可以自己去推一下,从同一层节点编号连续、节点编号奇偶性与其位置的关系等性质去想。

那么区间询问的代码也差不多。

ll range_query(ll l,ll r){
    ll ans(0);
    pushdown(l),pushdown(r);//下传
    l += n - 1,r += n - 1;//找到对应节点
    while (l <= r){
        if (l & 1)    ans = (ans + sgt[l ++]) % mod;
        if (~ r & 1)    ans = (ans + sgt[r --]) % mod;
        l >>= 1,r >>= 1;
    }//和区间修改差不多
    return ans % mod;//华丽输出
}

代码

时间复杂度和递归的版本一样,但是常数和码量都要小一点。

#include <bits/stdc++.h>
using namespace std;
#define f(n,m,i) for (register int i(n);i <= m;++ i)
#define fc(n,m,i) for (register int i(n);i >= m;-- i)
#define dbug(x) cerr<<(#x)<<':'<<x<<' ';
#define ent cerr<<'\n';
#define C ios::sync_with_stdio(false),cin.tie(0),cout.tie(0),cerr.tie(0);
#define ll long long
ll o,n,m,opt,x,y,z,h,sgt[2000005],atag[1000005],mtag[1000005];
const ll mod(571373);
bool tag[1000005];
void retag(ll x,ll add,ll mul,ll len){
    sgt[x] = (sgt[x] * mul + add * len) % mod;
    if(x < n){
        tag[x] = true;
        mtag[x] = mtag[x] * mul % mod;
        atag[x] = (atag[x] * mul + add) % mod;
    }
}
void pushup(ll x){
    if(!tag[x])    sgt[x] = (sgt[x << 1] + sgt[x << 1 | 1]) % mod;
}
void pushdown(ll x){
    ll s(h),len(1 << (h - 1));
    x += n - 1;
    while (s > 0){
        ll y(x >> s);
        if(tag[y]){
            retag(y << 1,atag[y],mtag[y],len);
            retag(y << 1 | 1,atag[y],mtag[y],len);
            tag[y] = 0,atag[y] = 0,mtag[y] = 1;
        }
        -- s,len >>= 1;
    }
}
void build(ll l,ll r){
    l += n - 1,r += n - 1;
    while (l > 1){
        l >>= 1,r >>= 1;
        fc(r,l,i)    pushup(i);
    }
}
void change(ll l,ll r,ll add,ll mul){
    ll cl(l),cr(r),len(1);
    pushdown(l),pushdown(r);
    l += n - 1,r += n - 1;
    while (l <= r){
        if(l & 1)   retag(l ++,add,mul,len);
        if(~ r & 1)   retag(r --,add,mul,len);
        l >>= 1,r >>= 1,len <<= 1;
    }
    build(cl,cl),build(cr,cr);
}
ll range_query(ll l,ll r){
    ll ans(0);
    pushdown(l),pushdown(r);
    l += n - 1,r += n - 1;
    while (l <= r){
        if (l & 1)    ans = (ans + sgt[l ++]) % mod;
        if (~ r & 1)    ans = (ans + sgt[r --]) % mod;
        l >>= 1,r >>= 1;
    }
    return ans % mod;
}
int main(){ C
    cin >> n >> m >> o;
    h = 32 - __builtin_clz(n);
    f(1,n - 1,i)    mtag[i] = 1;
    f(1,n,i){
        cin >> sgt[i + (n - 1)];
        sgt[i + (n - 1)] %= mod;
    }
    build(1,n);
    while (m --){
        cin >> opt >> x >> y;
        if(opt == 1)    cin >> z,change(x,y,0,z % mod);
        else if(opt == 2)   cin >> z,change(x,y,z % mod,1);
        else    cout << range_query(x,y) << '\n';
    }
    return 0;
}

后记&鸣谢

出处: %%% 清华大学张昆玮(zkw)《统计的力量》,我居然之前都不知道!还是太菜了。

感谢 @喝水 所写的 文章,本文是在其基础上结合个人理解写的。你们可以从他的文章中学习如何写出任意实数叉线段树!

感谢 @sidekick257 扔的一份 参考代码 比我写得更好

欢迎各位奆佬指出不足。