题解:P3373 【模板】线段树 2
lcfollower · · 题解
写在前面
本文为了符合模版题规范,从线段树的建树到查询一步步讲,读者可以自己调阅读的初始位置。
如果文章有错误可以评论,到时候我会更改。
修改
修改只有两天时间,通过了又再次修改,所以在此感谢审核的管理。
引入
先看这题。
首先暴力肯定挂。
考虑运用前缀和优化,但是修改操作的时间复杂度是不能接受的。
再考虑运用 BIT 模板题的方法,但是区间求和的时间复杂度也不能接受。
(修改,后面不会再标注)但是有只用 BIT 的做法,但是不是这题的重点。
于是我们选择用线段树解决。
线段树可分为普通线段树和权值线段树,这题用普通线段树(当然还能衍生出更多的线段树,比如李超线段树等),权值线段树的每一个节点代表数值(有时需要离散化)在这段区间的答案(不一定是问题答案,也可以是某个需要维护的值),这里讲普通线段树。
顾名思义,线段树是一棵树,设当前节点编号为 u << 1,右儿子编号为 u << 1 | 1(u << 1 后二进制末尾为
最开始
做入门线段树的题,首先要明确自己需要维护什么值,想清楚了再打代码。
这里明显需要维护区间和,还有一个等下面会讲(包括原因)。
建树操作
首先建立结构体:
const int N = 5e5 + 10;
struct SGT {//SGT = SegMent Tree = 线段树。
int l ,r ,sum;//sum 表示需要维护的区间和。
} tr[N << 2];
注意开
然后是正式的建树操作。
代码实现中,我们使用 build (u ,l ,r) 表示节点
- 叶节点(递归边界):即
l = r ,节点建立后返回; - 创建左儿子:先取
mid=\frac{l+r}{2} 再递归build (u << 1 ,l ,mid); - 创建右儿子:仍旧有
mid = \frac{l + r}{2} ,然后递归build (u << 1 | 1 ,mid + 1 ,r)。
创建叶节点的节点建立表示将 tr[u].sum 值赋为
最后不要忘了只给叶节点的 pushup)区间信息。
完整建树代码如下:
const int N = 5e5 + 10;
struct node {
int l ,r ,sum;//sum 表示需要维护的区间和。
} tr[N << 2];
inline void pushup(int u){
//pushup(u) 表示将 u 的左儿子和右儿子的信息合并到 u 的信息上。
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
/*注意:由此可得线段树维护的值必须具有结合律,比如区间和、区间积、区间异或和、区间最值等!
参考下述:
tr[1].sum = tr[2].sum + tr[3].sum
= (tr[4].sum + tr[5].sum) + (tr[6].sum + tr[7].sum),
如果加法不具有结合律普通线段树根本维护不了区间和。*/
} inline void build (int u,int l,int r){
tr[u].l = l ,tr[u].r = r;
if(l == r){tr[u].sum = a[L];return;}//叶节点。
int mid = ((L + R) >> 1);//取中点。
/*建立左儿子和右儿子*/
build(u << 1 ,L ,mid);
build(u << 1 | 1 ,mid + 1, R);
pushup(u);//别忘了合并信息。
}
这样我们就建立了一棵和上图相当的线段树。
这样总共会建立至多
以下开始默认
单点修改
单点修改,指的是将某一个下标的数值进行一次操作,可以加、减、乘、除、异或等。
首先先给这个函数定义名为 update。
考虑对于下标 pushup。因为可以发现未遍历过的节点都无需更改,所以 pushup 只要回溯时进行即可。
题目中是进行区间加法。
具体代码实现如下:
inline void update(int u ,int x ,int v){
//u 表示当前节点编号,x 表示需要修改的序列下标,v 为增加值。
int l = tr[u].l ,r = tr[u].r;
if(l == r){//就是这个叶子节点。
//由于这是逐步逼近这个节点,搜索到的只要是叶节点就一定 l = r = x,所以可以不写 l(r) == x。
tr[u].sum += (r - l + 1) * v;
//区间每个数增加 v,总和增加为:区间长度 * v。
return;//不要忘了回溯!
}
int mid = ((l + r) >> 1);
if(x <= mid) update(u << 1 ,L ,R ,v);//如果在左子树。
else update(u << 1 | 1, L , R ,v);//否则当然在右子树啦。
pushup(u);//合并信息,因为修改过了。
}
考虑到每次走到一个儿子,都会排除约一半的节点,所以
区间修改
就是题目所求操作。
聪明的读者学会单点修改后,可以执行
于是我们引进一个新的概念:懒标记。
我们定义这题的懒标记为其所有儿子节点需要增加的值,「下传」(下面会有解释)后清零。
一开始所有节点的懒标记为
设当前节点
- 如果
[l,r]\in [L,R] ,那么把u 的懒标记增加v 并对其区间和进行修改,然后回溯,因为整个区间都会在修改区间内,因此可以直接修改。 - 否则,如果
L\le mid ,说明左子树有需要修改的区间,往左子树递归。 - 如果
mid + 1 \le R ,则说明右子树有需要修改的区间,往右子树递归。
最后不要忘了 pushup。
完成后如下图:
然后考虑修改
此时会在
为什么呐?原因是程序并没有把 pushdown。
此时我们将
最最重要的,已经「下传」了,记得把
弄清楚懒标记和其下传后,我们给它一个名字:
结构体如下:
struct SGT{
int l ,r ,sum ,lazy;//lazy 为懒标记。
}tr[N << 2];
pushdown 代码如下:
inline void pushdown(int u){
//把节点 u 的懒标记「下传」到他的儿子。
if(tr[u].lazy){//这个节点没有懒标记可以不「下传」。
/*左、右儿子懒标记分别加上 tr[u].lazy。*/
tr[u << 1].lazy += tr[u].lazy;
tr[u << 1 | 1].lazy += tr[u].lazy;
/*更改左、右儿子的区间和。*/
tr[u << 1].sum += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].lazy;
tr[u << 1 | 1].sum += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].lazy;
tr[u].lazy = 0; //清空懒标记!
}
}
效果如下:
update 代码如下:
inline void update(int u,int L,int R,int v){
//当前节点为 u,需要修改的区间为 [L,R],需要增加 v。
int l = tr[u].l ,r = tr[u].r;
if(l >= L && r <= R){//节点对应的区间完全在修改区间内。
tr[u].sum += (r - l + 1) * v;//更改区间和。
tr[u].lazy += v;//增加懒标记。
return;
}
pushdown(u);//「下传」懒标记。
int mid = ((l + r) >> 1);
/* 修改。
注意:不能写 else,因为 L <= mid 的时也可能满足 mid < R。
*/
if(L <= mid) update(u << 1 ,L ,R ,v);
if(mid < R) update(u << 1 | 1, L , R ,v);
pushup(u);//合并信息。
}
这样时间复杂度就为
::::info[证明]
我们仍旧设当前节点代表区间为
以及两个定义:
-
完整节点:
[l ,r] \in [L,R] 。 -
部分节点:
[l,r]\cap [L,R] \neq \varnothing 且[l,r]\notin [L,R] 。
:::info[最多
我们发现,部分节点一定会出现在区间两端,因此最多
因为我们最多向下递归
:::
:::info[最多两个完整节点] 考虑到如果两个完整节点互为兄弟节点,那么它们的父亲一定也是完整节点,会在上一层回溯掉。
又考虑线段树是二叉树,因此两个完整节点来自左右两个区间(因为如果两个完整节点互为兄弟节点那么一定会在它们的父节点回溯)。 :::
所以一层最多访问
直接上图片理解:
:::align{center} :::
::::
单点查询
同理也很容易写出来,注意需要查询所以沿途的所有懒标记都要「下传」,因为我们想要的答案是要真实的,而非打上懒标记的。
inline int query(int u,int x){
//当前节点为 u,需要查询下标为 x。
int l = tr[u].l ,r = tr[u].r;
if(l == x) return tr[u].sum;
pushdown(u);//下传懒标记。
int mid = ((l + r) >> 1) ,val = 0;
/* 计算左、右子树答案。 */
if(L <= mid) val += query(u << 1 ,L , R);
if(mid < R) val += query(u << 1 | 1, L ,R);
return val;
}
时间复杂度为
区间查询
同理,稍微改一下就可以了,时间复杂度仍为
inline int query(int u,int L,int R){
//当前节点为 u,需要查询的区间为 [L ,R]。
int l = tr[u].l ,r = tr[u].r;
if(l >= L && r <= R) return tr[u].sum;//[L ,R] 如果包含 [l ,r] 就直接返回答案。
pushdown(u);
int mid = ((l + r) >> 1) ,val = 0;
if(L <= mid) val += query(u << 1 ,L , R);
if(mid < R) val += query(u << 1 | 1, L ,R);
return val;
}
时间复杂度为
完整代码
把上面的代码组合一下就可以了。
总体时间复杂度为
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define up(i,x,y) for(register int i=x;i<=y;++i)
using namespace std;
inline int read(){int x=0;bool f=0;char ch=getchar();while(!isdigit(ch)){f|=(ch=='-');ch=getchar();}while(isdigit(ch))x=(x<<1)+(x<<3)+(ch^48),ch=getchar();return (f?-x:x);}
inline void write(int x){if(x<0)putchar('-'),x=-x;if(x>9)write(x/10);putchar(x%10|48);}
inline void writeln(int x){write(x),putchar('\n');}
inline void writesp(int x){write(x),putchar(' ');}
const int N = 1e5 + 10;
int n ,Q ,a[N];
struct SGT{int l ,r ,sum ,lazy;}tr[N << 2];
inline void pushup(int u){
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
} inline void pushdown(int u){
if(tr[u].lazy){
tr[u << 1].lazy += tr[u].lazy;
tr[u << 1 | 1].lazy += tr[u].lazy;
tr[u << 1].sum += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].lazy;
tr[u << 1 | 1].sum += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].lazy;
tr[u].lazy = 0;
}
}inline void build(int u,int L,int R){
tr[u].l = L ,tr[u].r = R;
if(L == R){tr[u].sum = a[L];return;}
int mid = ((L + R) >> 1);
build(u << 1 ,L ,mid);
build(u << 1 | 1 ,mid + 1, R);
pushup(u);
} inline void update(int u,int L,int R,int v){
int l = tr[u].l ,r = tr[u].r;
if(l >= L && r <= R){
tr[u].sum += (tr[u].r - tr[u].l + 1) * v;
tr[u].lazy += v;
return;
}
pushdown(u);
int mid = ((l + r) >> 1);
if(L <= mid) update(u << 1 ,L ,R ,v);
if(mid < R) update(u << 1 | 1, L , R ,v);
pushup(u);
} inline int query(int u,int L,int R){
int l = tr[u].l ,r = tr[u].r;
if(l >= L && r <= R) return tr[u].sum;
pushdown(u);
int mid = ((l + r) >> 1) ,val = 0;
if(L <= mid) val += query(u << 1 ,L , R);
if(mid < R) val += query(u << 1 | 1, L ,R);
return val;
}signed main(){
n = read() , Q = read();
up(i, 1 ,n) a[i] = read();
build(1 ,1, n);
while(Q --){
int op = read() , L = read() , R = read();
if(op == 1) {int v = read();update(1 ,L ,R ,v);}
if(op == 2) writeln(query(1 ,L ,R));
}
return 0;
}
/*
Input:
5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
Output:
11
8
20
*/
回到本题
本题额外多了一个操作:区间乘法。
因此我们多维护一个懒标记
考虑到乘法分配律:
但是还有一个问题:乘法标记和加法标记,我们先下传哪个呢?又或者说顺序没关系?
我们来仔细想一下乘法标记(记作
-
会让(左右儿子的,后面省略)
sum 变成sum\times Mul (区间乘)。 -
会让
mul 变成mul\times Mul (显然)。 -
会让
add 变成add\times Mul (显然)。
还有加法标记(
-
会让
sum 变成sum + len\times Add ,len 为区间长度(区间加)。 -
会让
add 变成add + Add (显然)。 -
会让
mul ……等等!好像对mul 无影响!
也就是说,除了本身,
我们下传标记需要满足从无依赖到有依赖(和拓扑排序挺像,理由也可以借鉴),因此先下传
注意是类似,可能有的地方不太说得通。
:::info[然后就可以来看代码了。]
修改:这里修改了码风并且修改了
inline void pushdown(int u) {
# define mul(u) tr[u].mul
# define add(u) tr[u].add
# define sum(u) tr[u].sum
int l = tr[u].l ,r = tr[u].r ,mid = ((l + r) >> 1);
/* 上面适当偷懒(逃)。*/
if (tr[u].mul != 1){
int Mul = mul(u);
/* Mul 对各个信息的影响。*/
sum(u << 1) *= Mul ,sum(u << 1) %= p;
sum(u << 1 | 1) *= Mul ,sum(u << 1 | 1) %= p;
mul(u << 1) *= Mul ,mul(u << 1) %= p;
mul(u << 1 | 1) *= Mul ,mul(u << 1 | 1) %= p;
add(u << 1) *= Mul ,add(u << 1) %= p;
add(u << 1 | 1) *= Mul ,add(u << 1 | 1) %= p;
mul(u) = 1;//清空,注意是 1。
} if (tr[u].add) {
int Add = add(u);
/* Add 对各个信息的影响。*/
sum(u << 1) += (mid - l + 1) * Add % p ,sum(u << 1) %= p;
sum(u << 1 | 1) += (r - mid) * Add % p ,sum(u << 1 | 1) %= p;
add(u << 1) += Add ,add(u << 1) %= p;
add(u << 1 | 1) += Add ,add(u << 1 | 1) %= p;
add(u) = 0;
}
}
:::
还有注意地方是注意取模,以及各个修改函数是否把所有影响信息修改到位。
:::info[完整代码]
修改:修改了码风。
#include<bits/stdc++.h>
#define int long long
#define rep(i ,x ,y) for(int i = x ; i <= y ; i ++)
#define pr printf
using namespace std;
inline int read(){int x = 0, f = 0;char ch = getchar();while (!isdigit(ch)) {f |= (ch == '-');ch = getchar();}while (isdigit(ch))x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();return f ? -x : x;}
const int N = 1e5 + 10;
struct SGT {
int l ,r ,sum ,mul ,add;
} tr[N << 2];
int n ,Q ,p ,k ,a[N];
inline void pushup(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
inline void pushdown(int u) {
# define mul(u) tr[u].mul
# define add(u) tr[u].add
# define sum(u) tr[u].sum
int l = tr[u].l ,r = tr[u].r ,mid = ((l + r) >> 1);
if (tr[u].mul != 1){
int Mul = mul(u);
sum(u << 1) *= Mul ,sum(u << 1) %= p;
sum(u << 1 | 1) *= Mul ,sum(u << 1 | 1) %= p;
mul(u << 1) *= Mul ,mul(u << 1) %= p;
mul(u << 1 | 1) *= Mul ,mul(u << 1 | 1) %= p;
add(u << 1) *= Mul ,add(u << 1) %= p;
add(u << 1 | 1) *= Mul ,add(u << 1 | 1) %= p;
mul(u) = 1;
} if (tr[u].add) {
int Add = add(u);
sum(u << 1) += (mid - l + 1) * Add % p ,sum(u << 1) %= p;
sum(u << 1 | 1) += (r - mid) * Add % p ,sum(u << 1 | 1) %= p;
add(u << 1) += Add ,add(u << 1) %= p;
add(u << 1 | 1) += Add ,add(u << 1 | 1) %= p;
add(u) = 0;
}
}
inline void build(int u ,int l ,int r) {
tr[u].l = l;
tr[u].r = r;
tr[u].mul = 1;
if(l == r) {
tr[u].sum = a[l] % p;
return;
}
int mid = ((l + r) >> 1);
build(u << 1 ,l ,mid);
build(u << 1 | 1 ,mid + 1 ,r);
pushup(u);
}
inline void muls(int u,int L,int R,int d) { //乘法操作。
int l = tr[u].l ,r = tr[u].r ,mid = ((l + r) >> 1);
if(l >= L && r <= R) {
tr[u].add = tr[u].add * d % p;
tr[u].mul = tr[u].mul * d % p;
tr[u].sum = tr[u].sum * d % p;//修改方式都 * d,add 和 sum 是因为结合律。(
return;
}
pushdown(u);
if(L <= mid) muls (u << 1 ,L ,R ,d);
if(mid < R) muls (u << 1 | 1 ,L ,R ,d);
pushup(u);
}
inline void adds(int u , int L , int R , int d) {
int l = tr[u].l ,r = tr[u].r ,mid = ((l + r) >> 1);
if(l >= L && r <= R) {
tr[u].add = (tr[u].add + d) % p;
tr[u].sum = (tr[u].sum + (r - l + 1) * k % p) % p;
return;
}
pushdown(u);
if(L <= mid) adds(u << 1 ,L ,R ,d);
if(mid < R) adds(u << 1 | 1 ,L ,R ,d);
pushup(u);
}
inline int query(int u , int L , int R) {
int l = tr[u].l ,r = tr[u].r ,mid = ((l + r) >> 1);
if(l >= L && r <= R) return tr[u].sum;
pushdown(u);
int val = 0;
if(L <= mid) val = (val + query(u << 1 ,L ,R) % p) % p;
if(mid < R) val = (val + query(u << 1 | 1 ,L ,R) % p) % p;
return val;
}
signed main() {
n = read() ,Q = read() ,p = read();
rep(i ,1 ,n) a[i] = read();
build(1 ,1 ,n);
while(Q --) {
int op = read() ,x = read() ,y = read();
if(op == 1) {
k = read();
muls (1 ,x ,y ,k);
} else if(op == 2) {
k = read();
adds (1 ,x ,y ,k);
} else pr("%lld\n" , query (1 , x , y));
}
return 0;
}
:::
练习一下标记关系
比如有区间覆盖标记
因为
所以先下放
另一种理解方式
假设当前是
- 区间覆盖操作可以直接清空
add 和mul ,管你前面怎么加怎么减。
假设当前是
-
如果存在
cov ,需要把cov 下放,这样才能更新mul 和add 。 -
不存在最好。
因此
最后
模版题到此结束!
如果是区间异或、最值等只需要改一下 pushup 和 pushdown 还有 update 内的一些不固定代码。
区间翻转、反转等不会可以看别的题解,你会有一种突然悟了的感觉,维护的信息也有点套路。
如果想要线段树打得很熟练,就要多打,每次打最好换题,也可以还是同样的题,打多了就习惯了。
做多了你会发现最难的是 pushdown,写的时候要注意标记的下传顺序以及互相的影响关系。
如果有帮助请给个赞。
真的最后
如果你觉得二叉太慢了想写三叉完全没问题,虽然能节省时间但是代码细节增多,所需要空间也增多。
一般题目不会卡时限,所以二叉就足够了。
::::info[与模板无关内容] 怎么能没有 LCT 做法呢?
给每个
但是这题没有断边和加边操作,加上断边和加边操作应该就是这题了。
:::info[代码]
# include <bits/stdc++.h>
# define int long long
# define up(i, x, y) for (int i = x; i <= y; i++)
# define dn(i, x, y) for (int i = x; i >= y; i--)
# define inf 1e18
using namespace std;
inline int read(){int x = 0, f = 0;char ch = getchar();while (!isdigit(ch)) {f |= (ch == '-');ch = getchar();}while (isdigit(ch))x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();return f ? -x : x;}
inline void write(int x){if (x < 0)putchar('-'), x = -x;if (x > 9)write(x / 10);putchar(x % 10 | 48);}
inline void writeln(int x){write(x), putchar('\n');}
inline void writesp(int x){write(x), putchar(' ');}
const int N = 1e5 + 10;
int n ,m ,a[N] ;
# define lc(u) tr[u].ch[0]
# define rc(u) tr[u].ch[1]
# define fa(u) tr[u].p
# define rev(u) tr[u].rev
# define sum(u) tr[u].sum
# define value(u) tr[u].value
# define add(u) tr[u].add
# define mul(u) tr[u].mul
# define sz(u) tr[u].sz
struct Link_Cut_Tree {
int p ,ch[2] ,rev ,sum ,value ,add ,mul ,sz;
}tr[N];
//LCT 节点数的个数是子树大小,所以要维护 sz。
int mod;
namespace LCT{
inline bool isroot (int x) {return !(lc(fa(x)) == x || rc(fa(x)) == x);}
inline void pushup (int x) {
sum(x) = (sum(lc(x)) + sum(rc(x)) + value(x)) % mod;
sz(x) = sz(lc(x)) + sz(rc(x)) + 1;
} inline bool get (int x) {return rc(fa(x)) == x;}
inline void pushmul (int x ,int k) {//下传乘法标记。
mul(x) *= k ,mul(x) %= mod;
add(x) *= k ,add(x) %= mod;
value(x) *= k ,value(x) %= mod;
sum(x) *= k ,sum(x) %= mod;
} inline void pushadd (int x ,int k) {//下传加法标记。
add(x) += k ,add(x) %= mod;
value(x) += k ,value(x) %= mod;
sum(x) += k * sz(x) % mod ,sum(x) %= mod;
} inline void pushdown (int x) {
if (mul(x) ^ 1) {
pushmul (lc(x) ,mul(x)) ,
pushmul (rc(x) ,mul(x));
}
if (add (x)) {
pushadd (lc(x) ,add(x)) ,
pushadd (rc(x) ,add(x));
} if (rev(x)){
swap(lc(x) ,rc(x));
rev(lc(x)) ^= 1;
rev(rc(x)) ^= 1;
}
mul(x) = 1 ,add(x) = 0 ,rev(x) = 0;;
} inline void alldown (int x) {
if (!isroot(x)) alldown(fa(x));
pushdown(x);
} inline void rotate (int x) {
int y = fa(x) ,z = fa(y);
bool k = get(x);
if (!isroot(y)) tr[z].ch[get(y)] = x;
fa(x) = z;
fa(tr[x].ch[k ^ 1]) = y ,tr[y].ch[k] = tr[x].ch[k ^ 1];
tr[x].ch[k ^ 1] = y ,fa(y) = x;
pushup (y) ,pushup (x);
} inline void splay (int x) {
alldown (x);
while (!isroot(x)) {
int f = fa(x);
if (!isroot(f))
if (get(f) == get(x)) rotate (f);
else rotate (x);
rotate (x);
}
} inline void access (int x) {
int y = 0;
while (x) {
splay (x) ,rc(x) = y ,y = x;
pushup (x) ,x = fa(x);
}
} inline void makeroot (int x) {
access (x) ,splay (x) ,rev(x) ^= 1;
} inline int find (int x) {
access (x) ,splay (x);
while (lc(x))
pushdown (x) ,x = lc(x);
splay (x);
return x;
} inline void split (int x ,int y) {
makeroot (x);
access (y);
splay (y);
} inline void link (int x ,int y) {
makeroot (x);
if (find(y) ^ x) fa(x) = y;
} inline void cut (int x ,int y) {
makeroot (x);
if (find(y) == x && fa(y) == x && !lc(y)) fa(y) = rc(x) = 0;
}
} namespace lolcrying {
signed main () {
n = read () ,m = read () ,mod = read ();
up (i ,1 ,n) value(i) = read ();
up (i ,1 ,n - 1) LCT :: link (i ,i + 1);//连边。
while (m --) {
int op = read () ,x = read () ,y = read();
if (op == 2){
int k = read ();
LCT :: split (x ,y);
LCT :: pushadd (y ,k);
} if (op == 3) {
LCT :: split (x ,y) ;
writeln (sum(y));
}
if (op == 1){
int k = read ();
LCT :: split (x ,y);
LCT :: pushmul (y ,k);
}
}
return 0;
}
}
signed main () {
int T = 1;
while (T --) lolcrying :: main ();
return 0;
}
:::
::::
参考文献
-
oi-wiki。
-
以及洛谷两个线段树模板的题解。
鸣谢
-
审核的管理。
-
指出过有效错误过的:@shawn0618,@Outer_Horizon。
-
指出过绝对的语言的:@return_third。
欢迎指出更多错误!