浅析线段树实现
lateworker · · 算法·理论
前言
相信前来阅读本文的人,多少对线段树有些了解,故在此不再赘述线段树的原理。本文将聚焦于线段树的实现方法,以及代码复杂度的优化。
[2025.08.02] 现在应该还缺几张图,晚点再补吧。
普通线段树(zkw 线段树)
例题:区间加区间求和线段树
zkw 线段树是一种小常数的非递归式线段树,完全可以代替非动态开点的普通线段树。截至本文撰写之日,虽然已有一些博客详细介绍了 zkw 线段树,但他们的实现都不优。我们可以在修改之前先 pushdown,从而去标记永久化。我们还可以使用左闭右开的神奇写法,将其空间复杂度优化至
结点结构设计
记线段树维护的数组为
考虑一个线段树上结点
由于要做懒标记,它还要维护 tag[u] += dx,就相当于告诉程序,
然而只维护上述两个信息是不够的。我们还是考虑结点
不难发现对于每个
在传统递归线段树的实现中,区间左右端点
至此,线段树结点的双半群结构已经完备。
辅助函数
在实现 update 和 query 函数前,不妨先实现几个辅助函数。
线段树维护信息时,需要考虑三种 apply 的情况:
-
tag[u]\rightarrow tag[v] -
tag[u]\rightarrow sum[v] -
sum[u]\leftarrow sum[v]
其中
因此,我们可以实现一个 modify(u, val) 用来 apply 前两种情况;实现一个 pushup(u) 用来 apply 第三种情况。有了这两个函数,pushdown(u) 的实现便是平凡的,调用 modify 函数即可。
intl st[N * 2 + 10], tg[N * 2 + 10], sz[N * 2 + 10];
// st:区间和;tg:懒标记;sz:区间长度;
void pushup(int u) {
if (!u) return; // 如果你不在意 0 号点的值,可以不写
st[u] = st[u << 1] + st[u << 1 | 1]; // sum[u] <-- sum[v]
}
void modify(int u, intl val) { // val 是一种特殊的 tag
st[u] += sz[u] * val; // sum[u] <-- val
tg[u] += val; // tag[u] <-- val
}
void pushdown(int u) {
if (tg[u]) {
// (sum[v], tag[v]) <--- tag[u]
modify(u << 1, tg[u]);
modify(u << 1 | 1, tg[u]);
tg[u] = 0;
}
}
建树
zkw 线段树的精髓在于
由于题目输入了原始数组,我们需要先给所有叶结点的
在建树时,我们还要预处理出
// st[u]:上文 siz[u]
// tg[u]:上文 tag[u],这里未出现
// sz[u]:上文 siz[u]
int main() {
// 省略其它输入
for (int i = 1; i <= n; i++) {
cin >> st[i + n]; // 直接把初始值放到叶结点上
sz[i + n] = 1;
}
// 自底向上合并,n - 1 为第一个非叶结点
for (int i = n - 1; i; i--) {
st[i] = st[i << 1] + st[i << 1 | 1];
sz[i] = sz[i << 1] + sz[i << 1 | 1];
}
}
区间查询
如果说辅助函数还很普通,那么接下来要介绍的区间查询和修改,就是 zkw 线段树的神奇之处了。
区间查询的第一步,还是要利用叶结点的偏移量 offset,精准定位两个“标兵结点”。同主流左开右开写法不同,我们采取左闭右开写法,即 l = l + offset, r = r + offset + 1。这种写法的优势在下文中有所提及,这里不再赘述。这种写法的算法流程如下:
- 如果
l 指向右子结点,就把这个结点算进答案里,然后l 自增。 - 如果
r 指向右子结点,就把这个结点的兄弟算进答案里,然后r 自减。 - 如果
l=r (跳到一起了),退出循环。 - 两个指针同时上移一层,循环上述过程。
我们将“需要统计的结点”简称为“目标结点”。可以证明:
- 如果
l 指向左子结点,这一层必定不存在目标节点,r 同理,因此直接上移一层即可。 - 如果
l 指向右子结点,下一个目标结点必然不是l 的祖先,因此l 需要自增,以跳到右边的树上。 - 如果
r 指向右子节点,其自减与否并不会影响其上移一层之后的位置。代码实现中,我们先让它自减跳到目标结点上,然后直接统计答案。
读者或许好奇,左开右开的写法明明更简单,为什么不采用呢?
左闭右开写法除了能省掉一个“左开”带来的虚点的空间开销外,实际上它还有一个神奇的 trick:叶结点 offset 不需要是 2 的幂次。
我们只需要把数组长度
这个结论的证明比较复杂,并且两个指针在树上的路径也很神奇。这里放个图给各位感受一下,也可以从中观察到一些特殊的 Case。
最后,区间查询是需要 pushdown 的。这个不难,我们把它放到区间修改的部分讲。
// 上文已经讲过算法流程,这里懒得写注释了
int query(int l, int r) {
l += n, r += n + 1;
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i); // 下传标记
int res = 0;
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) res += st[l++];
if (r & 1) res += st[--r];
} return res;
}
区间修改
先讲 pushdown。
我们已经实现了 pushdown 函数,用来下传某个点的标记。然而,在区间修改之前,我们需要先把根到左右指针路径上所有点的懒标记下传下去,这样在后续访问或修改的时候才是对的。切记:我们需要保证所有需要访问的点已经被 pushdown,在 zkw 线段树上就是叶子到根路径上的点。另外,其实多 pushdown 几次答案一定不会错,假设我们在每次操作中,把树上所有点都 pushdown 一遍,这就等于没有懒标记。
实现上,zkw 线段树的叶结点位于第 u >> 1, u >> 2, ..., u >> (__lg(n)+1),反向遍历并 pushdown 即可。
// 假设两个指向叶结点的指针分别为 l, r
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
确保标记都下传之后,我们就可以直接对所有目标结点做修改了(目标结点定义同上)。和 query 的流程基本一样,我们用 u = v = 0;,当
另外,
void update(int l, int r, intl val) {
l += n, r += n + 1;
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
// 遇到目标结点再赋值
if (l & 1) u = l, modify(l++, val);
if (r & 1) v = r, modify(--r, val);
// 以下 pushup 建议照抄,这是目前最简单的写法,注意要先上移再 pushup
do pushup(u >>= 1); while (l == r && u);
do pushup(v >>= 1); while (l == r && v);
}
}
完整实现
:::info[【模板】线段树 1]
#define ffopen(s) \
cin.tie(0)->sync_with_stdio(0); \
if (*#s) freopen(#s ".in", "r", stdin); \
if (*#s) freopen(#s ".out", "w", stdout); \
/* 神奇的 freopen 手法 */
#include <bits/stdc++.h>
using namespace std;
using intl = long long;
const int N = 100000;
int n, m;
intl st[N * 2 + 10], tg[N * 2 + 10], sz[N * 2 + 10];
void pushup(int u) {
if (!u) return;
st[u] = st[u << 1] + st[u << 1 | 1];
}
void modify(int u, intl val) {
st[u] += sz[u] * val;
tg[u] += val;
}
void pushdown(int u) {
if (tg[u]) {
modify(u << 1, tg[u]);
modify(u << 1 | 1, tg[u]);
tg[u] = 0;
}
}
void update(int l, int r, intl val) {
l += n, r += n + 1;
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
if (l & 1) u = l, modify(l++, val);
if (r & 1) v = r, modify(--r, val);
do pushup(u >>= 1); while (l == r && u);
do pushup(v >>= 1); while (l == r && v);
}
}
intl query(int l, int r) {
l += n, r += n + 1;
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
intl res = 0;
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) res += st[l++];
if (r & 1) res += st[--r];
} return res;
}
int main() { ffopen();
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> st[i + n];
sz[i + n] = 1;
}
for (int i = n - 1; i; i--) {
st[i] = st[i << 1] + st[i << 1 | 1];
sz[i] = sz[i << 1] + sz[i << 1 | 1];
}
for (int i = 1; i <= m; i++) {
int op, x, y; intl k;
cin >> op >> x >> y;
if (op == 1) { cin >> k; update(x, y, k); }
if (op == 2) { cout << query(x, y) << "\n"; }
}
return 0;
}
:::
zkw 线段树上二分
例题:区间加区间求和线段树二分
有些题目需要我们在值域上二分,然后用线段树的区间查询结果来判断当前位置是否合法。如果先二分再查询,这样的时间复杂度是双 log 的,往往不能接受。由于线段树本身具有二分结构,因此我们可以直接在线段树上二分,这样就可以做到单 log。
注意:如果要在 zkw 线段树上二分,偏移量 offset 必须补到 2 的幂次!未补齐的 zkw 线段树是不具有二分结构的!
int tn; // 在前面声明偏移量 tn
/*
... 省略其它代码 ...
*/
for (tn = 1; tn <= n + 1; tn <<= 1); // 补齐的代码很简单,一行搞定
线段树二分其实分为三个步骤:
- 在线段树上卡出二分的上下界。
- 一边遍历合法节点,一边记录前缀信息,找到一个包含了答案所在位置的线段树结点。
- 在该结点的子树内二分出答案。
在传统递归式写法中,这三个步骤可以一起完成。然而在 zkw 线段树中,我们需要把这三个步骤拆成两个函数来做。
本质来讲,我们可以把二分上下界当作区间查询。这样对应到线段树上,就是若干需要计入答案的结点(下文简称为“关键结点”)。线段树上二分要做的,无非是先找到答案所在的关键节点,然后在这个点对应的区间上二分。我们把前者和卡二分上下界一起在函数 lowerbs(l, r, x) 里实现,后者放到函数 fullbs(u, x, prf) 里实现。
lowerbs 函数
第一步,寻找关键结点。这是简单的,和上文区间操作的流程相同。
第二步,定位包含答案位置的关键结点。在本题中,操作 3 输入了一个数 fullbs(u, x, prf) 并返回答案即可。如果遍历完所有结点都没找到关键结点,就肯定无解了。
记得再访问结点之前先 pushdown 哦~
int lowerbs(int l, int r, intl x) {
static int rps[LN + 3]; // 用来存放右指针对应结点编号的栈,*rps 为栈顶编号
l += tn - 1, r += tn, *rps = 0; // 栈顶设为 0,清空栈
for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
intl prf = 0; // 前缀答案(不包括当前 u 结点)
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) { // 直接处理左链
int u = l++;
intl now = prf + st[u]; // 包含当前结点的前缀和
// 如果前缀和大于等于目标,调用 fullbs 继续二分,否则更新前缀答案
if (now >= x) return fullbs(u, x, prf);
prf = now;
}
if (r & 1) rps[++*rps] = --r; // 把右链结点塞进栈里
}
for (; *rps; --*rps) { // 弹栈,处理右链
int u = rps[*rps];
intl now = prf + st[u];
if (now >= x) return fullbs(u, x, prf);
prf = now;
}
return r + 1; // 报告无解
}
fullbs 函数
我们考虑把递归式线段树二分写成循环版本即可。
int fullbs(int u, intl x, intl prf) {
while (u < tn) {
pushdown(u);
intl now = prf + st[u <<= 1];
if (now < x) prf = now, u ^= 1;
} return u - tn + 1;
}
完整实现
:::info[【模板】线段树二分 2]
#define ffopen(s) \
cin.tie(0)->sync_with_stdio(0); \
if (*#s) freopen(#s ".in", "r", stdin); \
if (*#s) freopen(#s ".out", "w", stdout); \
/**/
#include <bits/stdc++.h>
using namespace std;
using intl = long long;
const int N = 1000000, LN = __lg(N) + 1;
int n, tn, m, sz[N * 3 + 10];
intl st[N * 3 + 10], tg[N * 3 + 10];
void pushup(int u) { if (u) st[u] = st[u << 1] + st[u << 1 | 1]; }
void modify(int u, intl val) { st[u] += val * sz[u], tg[u] += val; }
void pushdown(int u) {
if (tg[u]) {
modify(u << 1, tg[u]);
modify(u << 1 | 1, tg[u]);
tg[u] = 0;
}
}
void update(int l, int r, intl val) {
l += tn - 1, r += tn;
for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
if (l & 1) u = l, modify(l++, val);
if (r & 1) v = r, modify(--r, val);
do pushup(u >>= 1); while (l == r && u);
do pushup(v >>= 1); while (l == r && v);
}
}
void update(int u, intl val) {
u += tn - 1;
for (int i = __lg(tn); i; i--) pushdown(u >> i);
st[u] = val; do pushup(u >>= 1); while (u);
}
intl query(int l, int r) {
l += tn - 1, r += tn;
for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
intl res = 0;
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) res += st[l++];
if (r & 1) res += st[--r];
} return res;
}
int fullbs(int u, intl x, intl prf) {
while (u < tn) {
pushdown(u);
intl now = prf + st[u <<= 1];
if (now < x) prf = now, u ^= 1;
} return u - tn + 1;
}
int lowerbs(int l, int r, intl x) {
static int rps[LN + 3];
l += tn - 1, r += tn, *rps = 0;
for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
intl prf = 0;
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) {
int u = l++;
intl now = prf + st[u];
if (now >= x) return fullbs(u, x, prf);
prf = now;
}
if (r & 1) rps[++*rps] = --r;
}
for (; *rps; --*rps) {
int u = rps[*rps];
intl now = prf + st[u];
if (now >= x) return fullbs(u, x, prf);
prf = now;
}
return r + 1;
}
int main() { ffopen();
cin >> n >> m;
for (tn = 1; tn <= n; tn <<= 1);
for (int i = 1; i <= n; i++) {
cin >> st[i + tn - 1];
sz[i + tn - 1] = 1;
}
for (int u = tn - 1; u; u--) pushup(u), sz[u] = sz[u << 1] + sz[u << 1 | 1];
for (int i = 1; i <= m; i++) {
int op, x, y; intl z;
cin >> op;
if (op == 1) { cin >> x >> y >> z; update(x, y, z); }
else if (op == 2) { cin >> x >> y; cout << query(x, y) << "\n"; }
else if (op == 3) { cin >> x >> y >> z; cout << lowerbs(x, y, z) << "\n"; }
else if (op == 4) { cin >> x >> z; update(x, z); }
}
return 0;
}
:::