如何写出简单又强势的线段树
XYstarabyss · · 算法·理论
许多算法的本质是统计。线段树用于统计,是沟通原数组与前缀和的桥梁。
——《统计的力量》清华大学-张昆玮
前言&省流
想必递归实现的线段树写法大家都知道吧,如果还不知道的话看这里。现在我结合自己的理解介绍一种靠 递推 而非递归写法的线段树(即 zkw 线段树),努力结合图像用更加通俗的语言来讲这一种方法,造福更多像我这样的蒟蒻。
下面用 【模板】线段树 2 展示这一种写法的优越性。
这是我之前写的递归版线段树评测记录:
这是我现在写的 zkw 线段树评测记录:
可以看到无论是码量、内存用量还是运行时长都具有相当的优越性 可谓是多快好省。
那么这么简单又强势的西格蒙特树该怎么写呢?
实现
一般我们都是这么有所建树的:
即令根节点为
但是我们也可以先假定最底层的
由于在满二叉树下最底层有
效果如下:
我们就会发现它长相极其混沌邪恶……
那么该怎么办呢?
注意到节点二储存的区间并不完整,且包含的区间分布在左右两端,所以只要不是查询区间
所以我们大胆地抛弃节点一和节点二,效果就变成了这样:
节点 线段森林。
建树代码如下:
void pushup(ll x){
if(!tag[x]) sgt[x] = (sgt[x << 1] + sgt[x << 1 | 1]) % mod;
}//在没有懒标记的情况下将值从下往上传(这个函数还能重复利用)
void build(ll l,ll r){//这个也能重复利用
l += n - 1,r += n - 1;//找到l,r对应的节点
while (l > 1){//建树
l >>= 1,r >>= 1;//往上找父亲节点的范围
fc(r,l,i) pushup(i);//然后对于区间内的父亲节点更新值
}
}
build(1,n);
虽然节点一和节点二说是被抛弃了,但是只要不用就行了,不用一些乱七八糟的判断。
对于懒标记的操作比较常规,和更新值的操作整合到一起了。
void retag(ll x,ll add,ll mul,ll len){//更新x节点的信息,同时传入它要加的数、要乘的数以及它所含区间长度
sgt[x] = (sgt[x] * mul + add * len) % mod;//不管怎么样先更新本节点的数值
if(x < n){//如果这个节点存的信息所含区间长度不为1
tag[x] = true;//打上标记,准备往下传
mtag[x] = mtag[x] * mul % mod;//更新乘标记
atag[x] = (atag[x] * mul + add) % mod;//更新加标记
}
}
然后是考虑如何把标记下放。对于
h = 32 - __builtin_clz(n);//这样建树理论上的最大树高
void pushdown(ll x){//对于x往上溯源,找到它所有的祖先,下传它们。
ll s(h),len(1 << (h - 1));//先直接找到最高处(那个节点甚至可能没意义)
x += n - 1;//找到x对应的那一个节点
while (s > 0){//遍历x的祖宗直到x的父亲节点
ll y(x >> s);//找到那个祖宗对应的节点
if(tag[y]){//如果那一个祖宗有标记
retag(y << 1,atag[y],mtag[y],len);//更新它左儿子
retag(y << 1 | 1,atag[y],mtag[y],len);//更新它右儿子
tag[y] = 0,atag[y] = 0,mtag[y] = 1;//还原标记
}
-- s,len >>= 1;//继续往下遍历x的祖宗
}
}
到这里前期各种函数终于都准备好了,接下来是区间修改的函数。怎么实现呢?
void change(ll l,ll r,ll add,ll mul){//区间修改,整合了加操作和乘操作
ll cl(l),cr(r),len(1);//保存l,r的值以便上传,同时保存当前l,r所能包含的区间长
pushdown(l),pushdown(r);//先下传
l += n - 1,r += n - 1;//找到l,r对应的节点
while (l <= r){
if(l & 1) retag(l ++,add,mul,len);
if(~ r & 1) retag(r --,add,mul,len);
l >>= 1,r >>= 1,len <<= 1;
//就是把散块做了,然后一直往上层递推,在l>r时停止。
//具体实现:l%2=1时说明在一个散块中,r%2=0时说明在一个散块中。散块大小显然为1。
}
build(cl,cl),build(cr,cr);//循环利用build函数来对l,r的祖宗节点更新
}
是不是有一点懵逼?那么我们拿最上面那个 森林树修改区间
现在你们大概知道散块是什么意思了吧,总之这样我们就以一种巧妙的方式执行了操作。
至于为什么
//就是把散块做了,然后一直往上层递推,在l>r时停止。
//具体实现:l%2=1时说明在一个散块中,r%2=0时说明在一个散块中。散块大小显然为1。
可以自己去推一下,从同一层节点编号连续、节点编号奇偶性与其位置的关系等性质去想。
那么区间询问的代码也差不多。
ll range_query(ll l,ll r){
ll ans(0);
pushdown(l),pushdown(r);//下传
l += n - 1,r += n - 1;//找到对应节点
while (l <= r){
if (l & 1) ans = (ans + sgt[l ++]) % mod;
if (~ r & 1) ans = (ans + sgt[r --]) % mod;
l >>= 1,r >>= 1;
}//和区间修改差不多
return ans % mod;//华丽输出
}
代码
时间复杂度和递归的版本一样,但是常数和码量都要小一点。
#include <bits/stdc++.h>
using namespace std;
#define f(n,m,i) for (register int i(n);i <= m;++ i)
#define fc(n,m,i) for (register int i(n);i >= m;-- i)
#define dbug(x) cerr<<(#x)<<':'<<x<<' ';
#define ent cerr<<'\n';
#define C ios::sync_with_stdio(false),cin.tie(0),cout.tie(0),cerr.tie(0);
#define ll long long
ll o,n,m,opt,x,y,z,h,sgt[2000005],atag[1000005],mtag[1000005];
const ll mod(571373);
bool tag[1000005];
void retag(ll x,ll add,ll mul,ll len){
sgt[x] = (sgt[x] * mul + add * len) % mod;
if(x < n){
tag[x] = true;
mtag[x] = mtag[x] * mul % mod;
atag[x] = (atag[x] * mul + add) % mod;
}
}
void pushup(ll x){
if(!tag[x]) sgt[x] = (sgt[x << 1] + sgt[x << 1 | 1]) % mod;
}
void pushdown(ll x){
ll s(h),len(1 << (h - 1));
x += n - 1;
while (s > 0){
ll y(x >> s);
if(tag[y]){
retag(y << 1,atag[y],mtag[y],len);
retag(y << 1 | 1,atag[y],mtag[y],len);
tag[y] = 0,atag[y] = 0,mtag[y] = 1;
}
-- s,len >>= 1;
}
}
void build(ll l,ll r){
l += n - 1,r += n - 1;
while (l > 1){
l >>= 1,r >>= 1;
fc(r,l,i) pushup(i);
}
}
void change(ll l,ll r,ll add,ll mul){
ll cl(l),cr(r),len(1);
pushdown(l),pushdown(r);
l += n - 1,r += n - 1;
while (l <= r){
if(l & 1) retag(l ++,add,mul,len);
if(~ r & 1) retag(r --,add,mul,len);
l >>= 1,r >>= 1,len <<= 1;
}
build(cl,cl),build(cr,cr);
}
ll range_query(ll l,ll r){
ll ans(0);
pushdown(l),pushdown(r);
l += n - 1,r += n - 1;
while (l <= r){
if (l & 1) ans = (ans + sgt[l ++]) % mod;
if (~ r & 1) ans = (ans + sgt[r --]) % mod;
l >>= 1,r >>= 1;
}
return ans % mod;
}
int main(){ C
cin >> n >> m >> o;
h = 32 - __builtin_clz(n);
f(1,n - 1,i) mtag[i] = 1;
f(1,n,i){
cin >> sgt[i + (n - 1)];
sgt[i + (n - 1)] %= mod;
}
build(1,n);
while (m --){
cin >> opt >> x >> y;
if(opt == 1) cin >> z,change(x,y,0,z % mod);
else if(opt == 2) cin >> z,change(x,y,z % mod,1);
else cout << range_query(x,y) << '\n';
}
return 0;
}
后记&鸣谢
出处: %%% 清华大学张昆玮(zkw)《统计的力量》,我居然之前都不知道!还是太菜了。
感谢 @喝水 所写的 文章,本文是在其基础上结合个人理解写的。你们可以从他的文章中学习如何写出任意实数叉线段树!
感谢 @sidekick257 扔的一份 参考代码 比我写得更好。
欢迎各位奆佬指出不足。