浅析线段树实现

· · 算法·理论

前言

相信前来阅读本文的人,多少对线段树有些了解,故在此不再赘述线段树的原理。本文将聚焦于线段树的实现方法,以及代码复杂度的优化。

[2025.08.02] 现在应该还缺几张图,晚点再补吧。

普通线段树(zkw 线段树)

例题:区间加区间求和线段树

zkw 线段树是一种小常数的非递归式线段树,完全可以代替非动态开点的普通线段树。截至本文撰写之日,虽然已有一些博客详细介绍了 zkw 线段树,但他们的实现都不优。我们可以在修改之前先 pushdown,从而去标记永久化。我们还可以使用左闭右开的神奇写法,将其空间复杂度优化至 O(2n),不需将数组长度向上补全至 2^k。如果你完全没听说过 zkw 线段树,建议先去阅读这篇文章:浅谈zkw线段树(简单了解单点修改区间查询的做法即可)

结点结构设计

记线段树维护的数组为 a[i], i\in[1,n]

考虑一个线段树上结点 u\leftarrow[l, r],它需要直接维护 sum[u]=\sum_{i=l}^{r}{a[i]}

由于要做懒标记,它还要维护 tag[u]=dx,表示每个 a[i],i\in[l,r] 都需要加上 dx。何为“需要加”呢?我们考虑手动模拟区间修改的过程。在递归过程中,我们不可能将修改操作 apply 到每个线段树结点上,因此,对于结点 u,我们让 tag[u] += dx,就相当于告诉程序,u 子树内的每个节点都需要修改。需要注意,在下传标记前,u 子树内,只有 u 自己的区间和 sum[u] 是正确的。如果要访问 u 的儿子和子孙,必须先用 tag[u] 去修正 u 子树内的值。

然而只维护上述两个信息是不够的。我们还是考虑结点 u\leftarrow[l,r],给它加上一个 dx,看看上述两个信息有何变化。

tag[u]\rightarrow tag[u]+dx sum[u]=\sum_{i=l}^{r}{a[i]}\rightarrow\sum_{i=l}^{r}{(a[i]+dx)}=sum[u]+dx\times(r-l+1)

不难发现对于每个 u,还要维护对应的区间长度 r - l + 1

在传统递归线段树的实现中,区间左右端点 [l,r] 可以随着递归自然得到,这才不用刻意记录。但实际上它也是结点必须维护的信息之一。

至此,线段树结点的双半群结构已经完备。

辅助函数

在实现 update 和 query 函数前,不妨先实现几个辅助函数。

线段树维护信息时,需要考虑三种 apply 的情况:

  1. tag[u]\rightarrow tag[v]
  2. tag[u]\rightarrow sum[v]
  3. sum[u]\leftarrow sum[v]

其中 uv 的父结点。而修改操作中的 dx 可以视为一种特殊的 tag[u]

因此,我们可以实现一个 modify(u, val) 用来 apply 前两种情况;实现一个 pushup(u) 用来 apply 第三种情况。有了这两个函数,pushdown(u) 的实现便是平凡的,调用 modify 函数即可。

intl st[N * 2 + 10], tg[N * 2 + 10], sz[N * 2 + 10];
// st:区间和;tg:懒标记;sz:区间长度;
void pushup(int u) {
    if (!u) return; // 如果你不在意 0 号点的值,可以不写
    st[u] = st[u << 1] + st[u << 1 | 1]; // sum[u] <-- sum[v]
}
void modify(int u, intl val) { // val 是一种特殊的 tag
    st[u] += sz[u] * val; // sum[u] <-- val
    tg[u] += val;         // tag[u] <-- val
}
void pushdown(int u) {
    if (tg[u]) {
        // (sum[v], tag[v]) <--- tag[u]
        modify(u << 1, tg[u]);
        modify(u << 1 | 1, tg[u]);
        tg[u] = 0;
    }
}

建树

zkw 线段树的精髓在于 O(1) 定位叶结点,因此我们需要一个叶结点偏移量 offset。但由于实现的特殊性,我们并不需要特意地将 offset 补到 2 的幂次,直接使用数组长度 n 作为偏移量即可。(详见下文)

由于题目输入了原始数组,我们需要先给所有叶结点的 sum[u] 赋值,然后整棵树 pushup 即可。叶结点的树上编号为 u=i+n,i\in[1,n]

在建树时,我们还要预处理出 u 对应的区间长度 siz[u],可以先让叶结点的 siz[u]=1,然后自底向上合并即可。tag[u] 初始为 0,无需特殊处理。

// st[u]:上文 siz[u]
// tg[u]:上文 tag[u],这里未出现
// sz[u]:上文 siz[u]
int main() {
    // 省略其它输入
    for (int i = 1; i <= n; i++) {
        cin >> st[i + n]; // 直接把初始值放到叶结点上
        sz[i + n] = 1;
    }
    // 自底向上合并,n - 1 为第一个非叶结点
    for (int i = n - 1; i; i--) {
        st[i] = st[i << 1] + st[i << 1 | 1];
        sz[i] = sz[i << 1] + sz[i << 1 | 1];
    }
}

区间查询

如果说辅助函数还很普通,那么接下来要介绍的区间查询和修改,就是 zkw 线段树的神奇之处了。

区间查询的第一步,还是要利用叶结点的偏移量 offset,精准定位两个“标兵结点”。同主流左开右开写法不同,我们采取左闭右开写法,即 l = l + offset, r = r + offset + 1。这种写法的优势在下文中有所提及,这里不再赘述。这种写法的算法流程如下:

  1. 如果 l 指向右子结点,就把这个结点算进答案里,然后 l 自增。
  2. 如果 r 指向右子结点,就把这个结点的兄弟算进答案里,然后 r 自减。
  3. 如果 l=r(跳到一起了),退出循环。
  4. 两个指针同时上移一层,循环上述过程。

我们将“需要统计的结点”简称为“目标结点”。可以证明:

  1. 如果 l 指向左子结点,这一层必定不存在目标节点,r 同理,因此直接上移一层即可。
  2. 如果 l 指向右子结点,下一个目标结点必然不是 l 的祖先,因此 l 需要自增,以跳到右边的树上。
  3. 如果 r 指向右子节点,其自减与否并不会影响其上移一层之后的位置。代码实现中,我们先让它自减跳到目标结点上,然后直接统计答案。

读者或许好奇,左开右开的写法明明更简单,为什么不采用呢?

左闭右开写法除了能省掉一个“左开”带来的虚点的空间开销外,实际上它还有一个神奇的 trick:叶结点 offset 不需要是 2 的幂次。

我们只需要把数组长度 n 当作 offset 即可,顺便也实现了 O(2n) 的优秀空间复杂度。

这个结论的证明比较复杂,并且两个指针在树上的路径也很神奇。这里放个图给各位感受一下,也可以从中观察到一些特殊的 Case。

最后,区间查询是需要 pushdown 的。这个不难,我们把它放到区间修改的部分讲。

// 上文已经讲过算法流程,这里懒得写注释了
int query(int l, int r) {
    l += n, r += n + 1;
    for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i); // 下传标记
    int res = 0;
    for (; l < r; l >>= 1, r >>= 1) {
        if (l & 1) res += st[l++];
        if (r & 1) res += st[--r];
    } return res;
}

区间修改

先讲 pushdown。

我们已经实现了 pushdown 函数,用来下传某个点的标记。然而,在区间修改之前,我们需要先把根到左右指针路径上所有点的懒标记下传下去,这样在后续访问或修改的时候才是对的。切记:我们需要保证所有需要访问的点已经被 pushdown,在 zkw 线段树上就是叶子到根路径上的点。另外,其实多 pushdown 几次答案一定不会错,假设我们在每次操作中,把树上所有点都 pushdown 一遍,这就等于没有懒标记。

实现上,zkw 线段树的叶结点位于第 \log_2{n}+1 层。因此叶子 u 到根的所有结点必然是 u >> 1, u >> 2, ..., u >> (__lg(n)+1),反向遍历并 pushdown 即可。

// 假设两个指向叶结点的指针分别为 l, r
for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);

确保标记都下传之后,我们就可以直接对所有目标结点做修改了(目标结点定义同上)。和 query 的流程基本一样,我们用 l, r 两个指针定位叶结点,然后以同样的方法向上跳,然后用 modify 函数更新目标结点即可。不同之处在于,我们需要另开两个指针 u, v,用来处理 pushup。初始时 u = v = 0;,当 l, r 指针找到第一个目标结点后,u,v 才会被真正赋值并执行 pushup。因为我们不可以对于目标结点子树内的任何结点做 pushup,不难自证。

另外,l, r 指针相遇之后,u, v 需要继续向上一直 pushup 到根。但我们给出的参考实现中,用一种巧妙的方式避免了额外的代码讨论。

void update(int l, int r, intl val) {
    l += n, r += n + 1;
    for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
    for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
    // 遇到目标结点再赋值
        if (l & 1) u = l, modify(l++, val);
        if (r & 1) v = r, modify(--r, val);
    // 以下 pushup 建议照抄,这是目前最简单的写法,注意要先上移再 pushup
        do pushup(u >>= 1); while (l == r && u);
        do pushup(v >>= 1); while (l == r && v);
    } 
}

完整实现

:::info[【模板】线段树 1]

#define ffopen(s) \
cin.tie(0)->sync_with_stdio(0); \
if (*#s) freopen(#s ".in", "r", stdin); \
if (*#s) freopen(#s ".out", "w", stdout); \
/* 神奇的 freopen 手法 */
#include <bits/stdc++.h>
using namespace std;
using intl = long long;
const int N = 100000;
int n, m;
intl st[N * 2 + 10], tg[N * 2 + 10], sz[N * 2 + 10];
void pushup(int u) {
    if (!u) return;
    st[u] = st[u << 1] + st[u << 1 | 1];
}
void modify(int u, intl val) {
    st[u] += sz[u] * val;
    tg[u] += val;
}
void pushdown(int u) {
    if (tg[u]) {
        modify(u << 1, tg[u]);
        modify(u << 1 | 1, tg[u]);
        tg[u] = 0;
    }
}
void update(int l, int r, intl val) {
    l += n, r += n + 1;
    for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
    for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
        if (l & 1) u = l, modify(l++, val);
        if (r & 1) v = r, modify(--r, val);
        do pushup(u >>= 1); while (l == r && u);
        do pushup(v >>= 1); while (l == r && v);
    } 
}
intl query(int l, int r) {
    l += n, r += n + 1;
    for (int i = __lg(n) + 1; i; i--) pushdown(l >> i), pushdown(r >> i);
    intl res = 0;
    for (; l < r; l >>= 1, r >>= 1) {
        if (l & 1) res += st[l++];
        if (r & 1) res += st[--r];
    } return res;
}
int main() { ffopen();
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> st[i + n];
        sz[i + n] = 1;
    }
    for (int i = n - 1; i; i--) {
        st[i] = st[i << 1] + st[i << 1 | 1];
        sz[i] = sz[i << 1] + sz[i << 1 | 1];
    }
    for (int i = 1; i <= m; i++) {
        int op, x, y; intl k;
        cin >> op >> x >> y;
        if (op == 1) { cin >> k; update(x, y, k); }
        if (op == 2) { cout << query(x, y) << "\n"; }
    }
    return 0;
}

:::

zkw 线段树上二分

例题:区间加区间求和线段树二分

有些题目需要我们在值域上二分,然后用线段树的区间查询结果来判断当前位置是否合法。如果先二分再查询,这样的时间复杂度是双 log 的,往往不能接受。由于线段树本身具有二分结构,因此我们可以直接在线段树上二分,这样就可以做到单 log。

注意:如果要在 zkw 线段树上二分,偏移量 offset 必须补到 2 的幂次!未补齐的 zkw 线段树是不具有二分结构的!

int tn; // 在前面声明偏移量 tn
/*
... 省略其它代码 ...
*/
for (tn = 1; tn <= n + 1; tn <<= 1); // 补齐的代码很简单,一行搞定

线段树二分其实分为三个步骤:

  1. 在线段树上卡出二分的上下界。
  2. 一边遍历合法节点,一边记录前缀信息,找到一个包含了答案所在位置的线段树结点。
  3. 在该结点的子树内二分出答案。

在传统递归式写法中,这三个步骤可以一起完成。然而在 zkw 线段树中,我们需要把这三个步骤拆成两个函数来做。

本质来讲,我们可以把二分上下界当作区间查询。这样对应到线段树上,就是若干需要计入答案的结点(下文简称为“关键结点”)。线段树上二分要做的,无非是先找到答案所在的关键节点,然后在这个点对应的区间上二分。我们把前者和卡二分上下界一起在函数 lowerbs(l, r, x) 里实现,后者放到函数 fullbs(u, x, prf) 里实现。

lowerbs 函数

第一步,寻找关键结点。这是简单的,和上文区间操作的流程相同。

第二步,定位包含答案位置的关键结点。在本题中,操作 3 输入了一个数 x,我们需要找到第一个前缀和 \sum_{i=1}^{pos}{a[i]}\ge x 的位置 pos。其中 a[i] 是我们要维护的数组。因此,我们按照对应叶结点的下标,在树上从左到右遍历每个关键结点 u。如果前缀和大于等于目标,即 now=prf+st[u]\ge x,就说明答案一定在当前结点所对应的区间内,调用 fullbs(u, x, prf) 并返回答案即可。如果遍历完所有结点都没找到关键结点,就肯定无解了。

记得再访问结点之前先 pushdown 哦~

int lowerbs(int l, int r, intl x) {
    static int rps[LN + 3]; // 用来存放右指针对应结点编号的栈,*rps 为栈顶编号
    l += tn - 1, r += tn, *rps = 0; // 栈顶设为 0,清空栈
    for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
    intl prf = 0; // 前缀答案(不包括当前 u 结点)
    for (; l < r; l >>= 1, r >>= 1) {
        if (l & 1) { // 直接处理左链
            int u = l++;
            intl now = prf + st[u]; // 包含当前结点的前缀和
        // 如果前缀和大于等于目标,调用 fullbs 继续二分,否则更新前缀答案
            if (now >= x) return fullbs(u, x, prf);
            prf = now;
        }
        if (r & 1) rps[++*rps] = --r; // 把右链结点塞进栈里
    }
    for (; *rps; --*rps) { // 弹栈,处理右链
        int u = rps[*rps];
        intl now = prf + st[u];
        if (now >= x) return fullbs(u, x, prf);
        prf = now;
    }
    return r + 1; // 报告无解
}

fullbs 函数

我们考虑把递归式线段树二分写成循环版本即可。

int fullbs(int u, intl x, intl prf) {
    while (u < tn) {
        pushdown(u);
        intl now = prf + st[u <<= 1];
        if (now < x) prf = now, u ^= 1;
    } return u - tn + 1;
}

完整实现

:::info[【模板】线段树二分 2]

#define ffopen(s) \
cin.tie(0)->sync_with_stdio(0); \
if (*#s) freopen(#s ".in", "r", stdin); \
if (*#s) freopen(#s ".out", "w", stdout); \
/**/
#include <bits/stdc++.h>
using namespace std;
using intl = long long;
const int N = 1000000, LN = __lg(N) + 1;
int n, tn, m, sz[N * 3 + 10];
intl st[N * 3 + 10], tg[N * 3 + 10];
void pushup(int u) { if (u) st[u] = st[u << 1] + st[u << 1 | 1]; }
void modify(int u, intl val) { st[u] += val * sz[u], tg[u] += val; }
void pushdown(int u) {
    if (tg[u]) {
        modify(u << 1, tg[u]);
        modify(u << 1 | 1, tg[u]);
        tg[u] = 0;
    }
}
void update(int l, int r, intl val) {
    l += tn - 1, r += tn;
    for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
    for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
        if (l & 1) u = l, modify(l++, val);
        if (r & 1) v = r, modify(--r, val);
        do pushup(u >>= 1); while (l == r && u);
        do pushup(v >>= 1); while (l == r && v);
    }
}
void update(int u, intl val) {
    u += tn - 1;
    for (int i = __lg(tn); i; i--) pushdown(u >> i);
    st[u] = val; do pushup(u >>= 1); while (u);
}
intl query(int l, int r) {
    l += tn - 1, r += tn;
    for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
    intl res = 0;
    for (; l < r; l >>= 1, r >>= 1) {
        if (l & 1) res += st[l++];
        if (r & 1) res += st[--r];
    } return res;
}
int fullbs(int u, intl x, intl prf) {
    while (u < tn) {
        pushdown(u);
        intl now = prf + st[u <<= 1];
        if (now < x) prf = now, u ^= 1;
    } return u - tn + 1;
}
int lowerbs(int l, int r, intl x) {
    static int rps[LN + 3];
    l += tn - 1, r += tn, *rps = 0;
    for (int i = __lg(tn); i; i--) pushdown(l >> i), pushdown(r >> i);
    intl prf = 0;
    for (; l < r; l >>= 1, r >>= 1) {
        if (l & 1) {
            int u = l++;
            intl now = prf + st[u];
            if (now >= x) return fullbs(u, x, prf);
            prf = now;
        }
        if (r & 1) rps[++*rps] = --r;
    }
    for (; *rps; --*rps) {
        int u = rps[*rps];
        intl now = prf + st[u];
        if (now >= x) return fullbs(u, x, prf);
        prf = now;
    }
    return r + 1;
}
int main() { ffopen();
    cin >> n >> m;
    for (tn = 1; tn <= n; tn <<= 1);
    for (int i = 1; i <= n; i++) {
        cin >> st[i + tn - 1];
        sz[i + tn - 1] = 1;
    }
    for (int u = tn - 1; u; u--) pushup(u), sz[u] = sz[u << 1] + sz[u << 1 | 1];
    for (int i = 1; i <= m; i++) {
        int op, x, y; intl z;
        cin >> op;
        if (op == 1) { cin >> x >> y >> z; update(x, y, z); }
        else if (op == 2) { cin >> x >> y; cout << query(x, y) << "\n"; }
        else if (op == 3) { cin >> x >> y >> z; cout << lowerbs(x, y, z) << "\n"; }
        else if (op == 4) { cin >> x >> z; update(x, z); }
    }
    return 0;
}

:::