树状数组小记

· · 算法·理论

\large\color{00aacd}\textbf{树状数组(Binary Index Tree)}

发现自己之前零零碎碎地写过一些树状数组的应用,所以写这一篇文章来整合一下。

树状数组是一种支持单点修改,区间查询的精巧的数据结构,通常用于维护满足结合律可差分的运算和信息。又称二叉索引树(Binary Index Tree)、Fenwick Tree。

\color{00cd00}\text{单点修改,区间查询}

下面这张图展示了树状数组的原理(来源:OI-Wiki)。

其中 c_x 表示以 x 为右端点,长度为 {\rm lowbit}(x) 的区间的和。

例如,$10$ 在二进制表示下为 $10\underset{\blacktriangle}{\bf1}0$,加粗的就是最低位的 $1$,它的权值是 $2$,因此 $\rm lowbit(10)=2$。 再例如,$24$ 在二进制表示下为 $1\underset{\blacktriangle}{\bf1}000$,最低位的 $1$ 的权值为 $8$,因此 $\rm lowbit(24)=8$。 根据位运算知识,可以得到 `lowbit(x) = x & -x`,其中 `&` 为**按位与**运算。

如果一个数减去自己的 \rm lowbit,得到的数再减去自己的 \rm lowbit,不断重复,最终这个数一定会变成 0

例如 7(111)\overset{\!-1}{\longrightarrow}6(110)\overset{\!-2}{\longrightarrow}4(100)\overset{\!-4}{\longrightarrow}0

那么我们要计算 a_{1\dots7} 的和,就只需要求 c_7+c_6+c_4 即可。观察上图,看看是不是这样。

由此我们可以得到查询 a_{1\dots x} 的代码:

int query(int x)
{
    int ans = 0;
    while(x > 0)
    {
        ans += c[x];
        x -= lowbit(x);
    }
    return ans;
}

可以发现,树状数组通过将一段数划分成 O(\log n) 段数的和,从而能够实现高效的查询操作。

如果要求任意一段区间 a_{l\dots r} 的和,可以借助前缀和的思想,用 a_{1\dots r} 的和减去 a_{1\dots l-1} 的和,即 query(r) - query(l-1)。这也说明树状数组可以当成一个支持修改的前缀和来用。

如果要将 a_5 加上一个数 k 该如何处理?观察包含 a_5 的区间,只有 c_5c_6c_8。那么就只需要将 c_5c_6c_8 都加上 k 即可。而 6=5+\rm lowbit(5)8=6+\rm lowbit(6)。也就是说,在树状数组中,一个结点 x 的父亲是 x+{\rm lowbit}(x)。由此我们可以得到将 a_x 加上 k 的代码:

void update(int x, int k)
{
    while(x <= n)
    {
        c[x] += k;
        x += lowbit(x);
    }
}

显然,修改操作的时间复杂度也为 O(\log n)

例题:P3374 【模板】树状数组 1

以下是一份经过封装的极简树状数组代码,可以通过本题。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{ //树状数组
    int c[N], lowbit(int x){return x & -x;}
    void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
    int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
long long n, m;
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n >> m;
    for(int i=1, x; i<=n; i++) cin >> x, t.update(i, x);
    while(m --> 0){
        int op, x, y; cin >> op >> x >> y;
        if(op == 1) t.update(x, y);
        if(op == 2) cout << t.query(y) - t.query(x - 1) << "\n";
    }
    return 0;
}

\color{00cd00}\text{区间修改,单点查询}

例题:P3368 【模板】树状数组 2

借助差分的思想。定义差分数组 d_i = a_i - a_{i-1}。于是有 a_x = \sum\limits_{i=1}^x d_i。如果要在 a_{l\dots r} 加上 k,只需要让 d_l\gets d_l+k,d_{r+1}\gets d_{r+1}-k。使用树状数组维护这一过程即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{ //树状数组维护差分数组
    int c[N], lowbit(int x){return x & -x;}
    void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
    int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
int n, m, a[N];
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n >> m;
    for(int i=1; i<=n; i++) cin >> a[i], t.update(i, a[i] - a[i-1]);
    while(m --> 0){
        int op, x, y, k; cin >> op;
        if(op == 1) cin >> x >> y >> k, t.update(x, k), t.update(y + 1, -k);
        if(op == 2) cin >> x, cout << t.query(x) << "\n";
    }
    return 0;
}

\color{00cd00}\text{区间修改,区间查询}

例题:P3372 【模板】线段树 1

我们已经会了使用差分数组实现区间修改。接下来只要考虑如何区间查询。因为 a_i=\sum\limits_{j=1}^i d_j,所以 a_{1\dots x} 的和,即 \sum\limits_{i=1}^x a_i 就等于 \sum\limits_{i=1}^x \sum\limits_{j=1}^i d_j。可以发现对于每一个 d_j 一共加了 x-j+1 次。那么原式等于 \sum\limits_{j=1}^{x} d_j\times (x-j+1),也就是 \sum\limits_{j=1}^x d_j\times (x+1)-d_j\times j。于是发现我们需要 2 个树状数组 c_0,c_1,分别维护 d_jd_j\times j

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e5 + 5;
struct BIT{
    int c[2][N], lowbit(int x){return x & -x;}
    void update(int x, int k){
        for(int i=x; i<N; i+=lowbit(i)){
            c[0][i] += k, c[1][i] += k * x;
        }
    }
    int query(int x){
        int ans = 0;
        for(int i=x; i>0; i-=lowbit(i)){
            ans += c[0][i] * (x + 1) - c[1][i];
        }
        return ans;
    }
} t;
int n, m, a[N];
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n >> m;
    for(int i=1; i<=n; i++) cin >> a[i], t.update(i, a[i] - a[i-1]);
    while(m --> 0){
        int op, x, y, k; cin >> op >> x >> y;
        if(op == 1) cin >> k, t.update(x, k), t.update(y + 1, -k);
        if(op == 2) cout << t.query(y) - t.query(x - 1) << "\n";
    }
    return 0;
}

\color{00cd00}\text{逆序对}

例题:P1908 逆序对

逆序对,就是在一个序列 a_{1\dots n} 中,满足 1\le i \le j \le na_i>a_j 的有序对。

可以用 cnt_x 表示当前 x 出现的数量。从后往前遍历每一个 a_i,当前能与 a_i 匹配的逆序对数量就是 \sum\limits_{j<a_i} cnt_j,即小于 a_i 的数的数量。使用树状数组维护 cnt 数组即可。

你说 a_i\le 10^9,数组开不了那么大?注意到逆序对只关心数的相对大小,所以可以将数据离散化,这样值域就降到了 O(n)。本文不详细展开。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{
    int c[N], lowbit(int x){return x & -x;}
    void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
    int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
long long n, m, ans;
int a[N], rk[N];
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n;
    for(int i=1; i<=n; i++) cin >> a[i], rk[i] = a[i];
    sort(rk + 1, rk + 1 + n); int len = unique(rk + 1, rk + 1 + n) - (rk + 1);
    for(int i=1; i<=n; i++) a[i] = lower_bound(rk + 1, rk + 1 + len, a[i]) - rk;
    for(int i=n; i>=1; i--) ans += t.query(a[i] - 1), t.update(a[i], 1);
    cout << ans;
    return 0;
}

\color{00cd00}\text{二维树状数组}

和一维的树状数组类似,我们设 c_{x,y} 为右下角为 (x,y),向上高为 {\rm lowbit}(x),向左长为 {\rm lowbit}(y) 的矩阵的和。

单点修改时,也和一维一样,i,j 不断加上自己的 \rm lowbit,修改所有包含 a_{x,y}c_{i,j} 即可。

区间查询时,i,j 不断减去自己的 \rm lowbit,累加沿路上的 c_{i,j}

如果要查询任意一个矩阵 a_{x_1,y_1}\sim a_{x_2,y_2} 的和,可以用二维前缀和的方法,即 {\rm sum}(x2,y2)-{\rm sum}(x2,y1\!-\!1)-{\rm sum}(x1\!-\!1,y2)+{\rm sum}(x1\!-\!1,y1\!-\!1)

例题:P4054 [JSOI2009] 计数问题

这题的值域只有 100,因此我们只需要用 100 个二维树状数组分别统计每种权值的数量即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 3e2 + 5;
struct BIT{ //二维树状数组
    int c[N][N], lowbit(int x){return x & -x;}
    void update(int x, int y, int k){
        for(int i=x; i<N; i+=lowbit(i)){
            for(int j=y; j<N; j+=lowbit(j)){
                c[i][j] += k;
            }
        }
    }
    int query(int x, int y){
        int ans = 0;
        for(int i=x; i; i-=lowbit(i)){
            for(int j=y; j; j-=lowbit(j)){
                ans += c[i][j];
            }
        }
        return ans;
    }
    int query(int x1, int y1, int x2, int y2){
        return query(x2, y2) - query(x2, y1-1) - query(x1-1, y2) + query(x1-1, y1-1);
    }
} t[101];
long long n, m, Q;
int a[N][N];
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n >> m;
    for(int i=1; i<=n; i++){
        for(int j=1; j<=m; j++){
            cin >> a[i][j];
            t[a[i][j]].update(i, j, 1);
        }
    }
    for(cin >> Q; Q --> 0;){
        int op, c; cin >> op;
        if(op == 1){
            int x, y; cin >> x >> y >> c;
            t[a[x][y]].update(x, y, -1);
            t[c].update(x, y, 1);
            a[x][y] = c;
        }
        if(op == 2){
            int x1, y1, x2, y2; cin >> x1 >> x2 >> y1 >> y2 >> c;
            cout << t[c].query(x1, y1, x2, y2) << "\n";
        }
    }
    return 0;
}

要实现矩阵修改,也和一维时的方法一样,在二维数组上差分,维护差分数组即可。二维的差分数组定义为 d_{i,j} = a_{i,j} - a_{i-1,j} - a_{i, j-1} + a_{i-1, j-1},它满足 a_{x,y} = \sum\limits_{i=1}^x\sum\limits_{j=1}^y d_{i,j}

如果要同时实现矩阵修改和矩阵查询,也可以先差分,然后推一下式子:

\begin{aligned} &\sum_{p=1}^x\sum_{q=1}^y a_{p,q} \\ =&\sum_{p=1}^x\sum_{q=1}^y\sum_{i=1}^p\sum_{j=1}^q d_{i,j} \\ =&\sum_{i=1}^x\sum_{j=1}^y d_{i,j}\times (x-i+1) \times (y-j+1) \\ =&\sum_{i=1}^x\sum_{j=1}^y d_{i,j}\times(xy+x+y+1)-d_{i,j}\times i\times(y+1)-d_{i,j}\times j\times (x+1)+d_{i,j}\times i\times j \end{aligned}

于是我们用四个二维树状数组,分别维护 d_{i,j},d_{i,j}\times i,d_{i,j}\times j,d_{i,j}\times i \times j 即可。

例题:P4514 上帝造题的七分钟

#include <bits/stdc++.h>
using namespace std;
const int N = 3e3 + 5;
struct BIT{
    int c[4][N][N], lowbit(int x){return x & -x;}
    void update(int x, int y, int k){
        for(int i=x; i<N; i+=lowbit(i)){
            for(int j=y; j<N; j+=lowbit(j)){
                c[0][i][j] += k;
                c[1][i][j] += k * x;
                c[2][i][j] += k * y;
                c[3][i][j] += k * x * y;
            }
        }
    }
    int query(int x, int y){
        int ans = 0;
        for(int i=x; i; i-=lowbit(i)){
            for(int j=y; j; j-=lowbit(j)){
                ans += c[0][i][j] * (x + 1) * (y + 1)
                     - c[1][i][j] * (y + 1)
                     - c[2][i][j] * (x + 1)
                     + c[3][i][j];
            }
        }
        return ans;
    }
    void update(int x1, int y1, int x2, int y2, int k){
        update(x1, y1, k), update(x2+1, y1, -k), update(x1, y2+1, -k), update(x2+1, y2+1, k);
    }
    int query(int x1, int y1, int x2, int y2){
        return query(x2, y2) - query(x2, y1-1) - query(x1-1, y2) + query(x1-1, y1-1);
    }
} t;
long long n, m;
char op;
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> op >> n >> m;
    while(cin >> op){
        int x1, y1, x2, y2, k;
        cin >> x1 >> y1 >> x2 >> y2;
        if(op == 'L') cin >> k, t.update(x1, y1, x2, y2, k);
        if(op == 'k') cout << t.query(x1, y1, x2, y2) << "\n";
    }
    return 0;
}

\color{00cd00}\text{权值树状数组}

所谓权值数组,就是将权值作为下标,统计每种权值出现的次数。权值树状数组就是使用树状数组维护权值数组。我们在前面的“逆序对”一节中已经用到了权值树状数组,这里我们利用权值树状数组解决“查询全局第 k 小值”问题。

我们需要实现以下操作:

  1. 在序列中加入一个数 x
  2. 在序列中删除一个数 x
  3. 查询序列中第 k 小的数是多少。

对于操作 1/2,就是在权值数组中将 cnt_x \gets cnt_x\pm 1。对于操作 3,可以考虑二分 x,用树状数组查询小于 x 的数的数量,不断调整直到找到一个 x_0 满足 \operatorname{Sum}(1,x_0)<k\operatorname{Sum}(1,x_0+1)\ge k,此时 x_0+1 即为第 k 小的数。

二分法的时间复杂度为 O(\log^2 n)。实际上,我们有 O(\log n) 的方法解决这个问题。

把二分换成倍增。设 x=0sum=0,枚举 i=\log_2n\to 0

最终得到的 x 是满足 \operatorname{Sum}(1,x)<k 的最大值,x+1 即为第 k 小的数。

根据树状数组的美好性质,查询 \operatorname{Sum}(x+1,x+2^i) 只需要访问 c_{x+2^i} 就行了,不需要 O(\log n) 查询一遍。因此倍增法的时间复杂度仅为 O(\log n)

int get_kth(int k){
    int sum = 0, x = 0;
    for(int i=20; i>=0; i--){
        x += 1 << i;
        if(x > N || sum + c[x] >= k) x -= 1 << i;
        else sum += c[x];
    }
    return x + 1;
}

例题:P3369 【模板】普通平衡树

想不到吧,树状数组还能当平衡树用。

这题比上面还多了查询 x 的排名、前驱、后继的操作。对于求 x 的排名,就是 \operatorname{Sum}(x-1)+1。对于求 x 的前驱/后继,都可以转化为求排名和查询第 k 小的操作。

因为需要离散化,所以只能离线下来做。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
struct BIT{
    int c[N], lowbit(int x){return x & -x;}
    void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
    int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
    int get_kth(int k){
        int sum = 0, x = 0;
        for(int i=20; i>=0; i--){
            x += 1 << i;
            if(x > N || sum + c[x] >= k) x -= 1 << i;
            else sum += c[x];
        }
        return x + 1;
    }
} t;
int n, m, rl[N];
pair<int, int> q[N];
signed main(){
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin >> n;
    for(int i=1; i<=n; i++){
        auto &[opt, x] = q[i]; 
        cin >> opt >> x;
        if(opt != 4) rl[++m] = x;
    }
    sort(rl+1, rl+1+m), m = unique(rl+1, rl+1+m) - (rl+1);
    for(int i=1; i<=n; i++){
        auto [opt, x] = q[i];
        if(opt != 4) x = lower_bound(rl+1, rl+1+m, x) - rl;
        if(opt == 1) t.update(x, 1);
        if(opt == 2) t.update(x, -1);
        if(opt == 3) cout << t.query(x - 1) + 1 << "\n";
        if(opt == 4) cout << rl[t.get_kth(x)] << "\n";
        if(opt == 5) cout << rl[t.get_kth(t.query(x - 1))] << "\n";
        if(opt == 6) cout << rl[t.get_kth(t.query(x) + 1)] << "\n";
    }
    return 0;
}

可以发现,用权值树状数组实现普通平衡树的代码只有约 \tt{1KB},效率也比一众平衡树高不少。

\color{00cd00}\text{树状数组与 min/max}

需要注意的是,因为 \min/\max 不满足可差分性,所以普通的树状数组不能用于解决 RMQ 问题。但是查询前缀 \min/\max 是可以的,这可以用于一些 DP 的优化,如 P9097。\gcd 等满足结合律但不可差分的运算也是同理。

代码就是把普通树状数组里的 + 换成 \min/\max。有时候可能需要写一个构造函数将 c 初始化为无穷大/无穷小。

struct BIT{
    int c[N], lowbit(int x){return x & -x;}; BIT(){memset(c, 0x3f, sizeof(c));}
    void update(int x, int k){while(x < N) c[x] = min(c[x], k), x += lowbit(x);}
    int query(int x){int s = N; while(x) s = min(s, c[x]), x -= lowbit(x); return s;}
};

事实上,有一种 O(\log^2n) 的方法让树状数组维护不可差分信息,但是用处不大,本文不再赘述。

以上所有代码的树状数组都使用了结构体封装,大家可以直接拿来用 QwQ。