void update(int x, int d) {
for (x += n; x; x >>= 1) sgt[x] += d;
}
int query(int l, int r) {
int ans = 0;
for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
if (l&1) ans += sgt[l++];
if (r&1) ans += sgt[--r];
}
return ans;
}
传统的 1 开始闭区间代码:
void update(int x, int d) {
for (x += n - 1; x >= 1; x >>= 1) sgt[x] += d;
}
int query(int l, int r) {
int ans = 0;
for (l += n - 1, r += n - 1; l <= r; l >>= 1, r >>= 1) {
if (l&1) ans += sgt[l++];
if (~r&1) ans += sgt[r--];
}
return ans;
}
然后,这个东西也可以带标记(显然),甚至支持标记下传。对于单点操作,下传它的所有祖先;对于一个区间 [l,r],下传 l 和 r 的所有祖先即可。
更进一步地,这个 b 甚至完全没必要是整数!我们以上的处理过程中完全未利用树结构的性质(叉数固定)。
所以稍加修改就可以获得一个任意实数叉线段树的代码:
void update(int x, int d) {
for (x += n; x; x /= b) sgt[x] += d;
}
int query(int l, int r) {
int ans = 0;
for (l += n, r += n; int(l / b) != int(r / b); l /= b, r /= b) {
while (int(l / b) == int((l - 1) / b)) ans += sgt[l++];
while (int(r / b) == int((r - 1) / b)) ans += sgt[--r];
}
return accumulate(sgt + l, sgt + r, ans);
}
```cpp
void update(int x, int d) {
for (x += n - 1; x; x /= b) sgt[x] += d;
}
int query(int l, int r) {
int ans = 0;
for (l += n - 1, r += n - 1; int(l / b) != int(r / b); l /= b, r /= b) {
while (int(l / b) == int((l - 1) / b)) ans += sgt[l++];
while (int(r / b) == int((r + 1) / b)) ans += sgt[r--];
}
return accumulate(sgt + l, sgt + r + 1, ans);
}
```
当然,为了更好的常数,可以改为乘倒数或者预处理父亲(但是这个东西本来就是娱乐向 :P)
树状数组 1 的完整通过代码($1$ 下标闭区间,另一种写法参考原文):
```cpp
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 5e5 + 9;
constexpr double b = 2;
int n, m, sgt[N << 1];
void update(int x, int d) {
for (x += n - 1; x; x /= b) sgt[x] += d;
}
int query(int l, int r) {
int ans = 0;
for (l += n - 1, r += n - 1; int(l / b) != int(r / b); l /= b, r /= b) {
while (int(l / b) == int((l - 1) / b)) ans += sgt[l++];
while (int(r / b) == int((r + 1) / b)) ans += sgt[r--];
}
return accumulate(sgt + l, sgt + r + 1, ans);
}
signed main() {
cin.tie(nullptr)->sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> sgt[i + n - 1];
for (int i = (n << 1) - 1; i >= b; --i) sgt[int(i / b)] += sgt[i];
for (int op; m; --m)
if (cin >> op; op == 1) {
int x, d;
cin >> x >> d;
update(x, d);
} else {
int l, r;
cin >> l >> r;
cout << query(l, r) << '\n';
}
return cout << flush, 0;
}
```