题解:P3373 【模板】线段树 2

· · 题解

刚学习线段树肯定是要先写模板题。

如果你刚学习线段树,那么请先完成模板题1,再转移至此题。

算法介绍:

也许当你初学的时候,总是听到学习提高组的选手们提到线段树,所以可能会觉得线段树很难。

那么线段树到底是什么呢?

线段树是一个用来维护区间信息的数据结构。还有,线段树可以维护的内容需要满足可合并性。

一般来说,线段树都是构造成一棵二叉树。以维护区间和为例,线段树的结构如下图所示:

如图所示,设一个区间的值为 [1,2,3,4]47 号节点是叶子结点,赋初始值。2 号节点表示的区间为 [1,2],它的值为 4 号节点和 5 号节点之和。剩余节点同理。

过程:

如上文的图所示,我们要先将数组的值转化成树,也就是建树。学过二叉树的应该都知道,二叉树的一个节点 x 的左右孩子分别为 x \times 2x \times 2 + 1

以维护区间和为例。

建树: 考虑递归建树,具体实现如下:

inline int ls(int x){return x<<1;}//左孩子,右移一位相当于乘2
inline int rs(int x){return x<<1|1;}//右孩子,因为乘以了2,二进制末尾为0,异或1相当于加1
inline void push_up(int x){ans[x]=ans[ls(x)]+ans[rs(x)];}//将节点x的两个孩子的值赋值到x节点上
inline void build(int x,int l,int r){//x号节点维护的区间为[l,r]
    if(l==r){//如果x号节点维护一个点,直接赋值
        ans[x]=a[l];
        return;
    }
    int mid=l+r>>1;
    build(ls(x),l,mid);//左区间
    build(rs(x),mid+1,r);//右区间
    push_up(x);//合并
}

线段树的修改有两种,一种是区间修改,另一种是单点修改。类似的,查询有区间查询和单点查询。

单点修改: 首先单点修改,我们只要修改会受它影响的节点,也就是它的所有祖先结点,包括它自己。从节点 1 开始,向左右子树递归,向所要修改的点靠近。时间复杂度 \mathcal O(\log n)

具体代码实现:

inline void update(int x,int nowl,int nowr,int y,int k){//x号节点维护的区间为[nowl,nowr],修改的点为y
    if(l==nowl&&nowr==r){//找到x号节点
        ans[x]+=k;//x号节点加k
        return;
    }
    int mid=nowl+nowr>>1;//分成两个区间
    if(l<=mid)update(ls(x),nowl,mid,l,r,k);//向左子树寻找
    if(mid<r)update(rs(x),mid+1,nowr,l,r,k);//向右子树寻找
    push_up(x);//重新合并
}

区间修改: 其次区间修改就是修改一整个区间,那么可以运用上文的单点修改一个点一个点逐个修改过去,但是时间复杂度高,最坏情况能达到 \mathcal O(n \log n)

由此考虑优化,给出懒标记的意义。

简单来说,懒标记的作用就是延迟修改的操作,在不必要的时候标记,必要时再更改。

那么什么是必要的时候和不必要的时候呢?

你可以这样想,一直修改但没有查询,这样你修改后没有意义,这就是不必要的时候,当你查询时,就可以利用所标记的懒标记求值,这是必要的时候。运用懒标记后时间复杂度为 \mathcal O(\log n)

具体代码实现:

inline void f(int x,int l,int r,int k){//更新左或右孩子的懒标记和所维护的区间值
    tag[x]+=k;
    ans[x]+=k*(r-l+1); 
}
inline void push_down(int x,int l,int r){
    int mid=l+r>>1;
    f(ls(x),l,mid,tag[x]);//传到左子树
    f(rs(x),mid+1,r,tag[x]);//传到右子树
    tag[x]=0;//清空标记
}
inline void update(int x,int nowl,int nowr,int l,int r,int k){//x号节点维护区间[nowl,nowr],区间[l,r]内的每个值+k
    if(l<=nowl&&nowr<=r){//x号节点所维护的区间在所要修改的区间内
        ans[x]+=k*(nowr-nowl+1);//x号节点维护的区间长度为nowr-nowl+1,所以要加(nowr-nowl+1)个l
        tag[x]+=k;
        return;
    }
    push_down(x,nowl,nowr);//下传懒标记
    int mid=nowl+nowr>>1;
    if(l<=mid)update(ls(x),nowl,mid,l,r,k);//向左子树寻找
    if(mid<r)update(rs(x),mid+1,nowr,l,r,k);//向右子树寻找
    push_up(x);//重新合并
}

最后是查询,与修改类似。不多赘述,具体见代码。

单点查询:

inline int query(int x,int nowl,int nowr,int y){//x号节点维护的区间为[nowl,nowr],查询的点为y
    if(l==nowl&&nowr==r)return ans[x];
    int mid=nowl+nowr>>1,res=0;//分成两个区间
    if(l<=mid)res+=query(ls(x),nowl,mid,l,r);//向左子树寻找
    if(mid<r)query(rs(x),mid+1,nowr,l,r);//向右子树寻找
    return res;
}

区间查询:

inline void f(int x,int l,int r,int k){//更新左或右孩子的懒标记和所维护的区间值
    tag[x]+=k;
    ans[x]+=k*(r-l+1); 
}
inline void push_down(int x,int l,int r){
    int mid=l+r>>1;
    f(ls(x),l,mid,tag[x]);//传到左子树
    f(rs(x),mid+1,r,tag[x]);//传到右子树
    tag[x]=0;//清空标记
}
inline void query(int x,int nowl,int nowr,int l,int r){//x号节点维护区间[nowl,nowr],查询区间[l,r]内的每个值之和
    if(l<=nowl&&nowr<=r){return ans[x];}//x号节点所维护的区间在所要查询的区间内
    push_down(x,nowl,nowr);//下传懒标记
    int mid=nowl+nowr>>1;
    if(l<=mid)res+=query(ls(x),nowl,mid,l,r);//向左子树寻找
    if(mid<r)res+=query(rs(x),mid+1,nowr,l,r);//向右子树寻找
    return res;
}

时间复杂度都为 \mathcal O(\log n)

线段树基本操作讲完了,现在回到这道题。

思路:

这道题有两种修改:

  1. 区间 [l,r] 中每个数乘以 k
  2. 区间 [l,r] 中每个数加上 k

那么修改时加一个懒标记维护乘法运算即可。

解释起来太麻烦,具体原因结合代码即注释:

//注意取模
inline void f(int x,int l,int r,int k1,int k2){//k1为乘法懒标记,k2为加法懒标记
    ans[x]=(k1*ans[x]+(r-l+1)*k2%mod)%mod;//更新答案
/*
当区间乘以k时,两个懒标记以及答案都乘以k
当区间加上k时,只更新了答案和加法懒标记
所以答案更新时要先乘以乘法懒标记,再加上加法懒标记
否则更新后会得到错误的答案
*/
    (tag1[x]*=k1)%=mod;//更新乘法懒标记
    tag2[x]=(k2+k1*tag2[x])%mod;//更新加法懒标记
/*
证明:
设原值为x,原乘法懒标记为t1,原加法懒标记为t2
现乘法懒标记为T1,现加法懒标记为T2
则父节点懒标记下传前,修改后x为x*t1+t2*(r-l+1)
父节点懒标记下传后x为
(x*t1+t2*(r-l+1))*T1+T2*(r-l+1)
=x*t1*T1+T1*t2*(r-l+1)+T2*(r-l+1)
=x*t1*T1+(T1*t2+T2)*(r-l+1)
所以懒标记这样更新
*/
}
inline void push_down(int x,int l,int r){
//tag1是乘法懒标记,tag2是加法懒标记
    int mid=l+r>>1;
    f(ls(x),l,mid,tag1[x],tag2[x]);//懒标记下传到左子树
    f(rs(x),mid+1,r,tag1[x],tag2[x]);//懒标记下传到右子树
    tag1[x]=1,tag2[x]=0;//懒标记清空,乘法懒标记为1
}

最后的完整代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
const int N=1e5+5;
int n,m,mod;
int a[N];
int ans[N*4];
int tag1[N*4],tag2[N*4];
int ls(int x){return x<<1;}
int rs(int x){return x<<1|1;}
void push_up(int x){ans[x]=ans[ls(x)]+ans[rs(x)];}
void build(int x,int l,int r){
    tag1[x]=1;
    if(l==r){ans[x]=a[l];return;}
    int mid=l+r>>1;
    build(ls(x),l,mid);
    build(rs(x),mid+1,r);
    push_up(x);
}
inline void f(int x,int l,int r,int k1,int k2){
    ans[x]=(k1*ans[x]+(r-l+1)*k2%mod)%mod;
    (tag1[x]*=k1)%=mod;
    tag2[x]=(k2+k1*tag2[x])%mod;
}
inline void push_down(int x,int l,int r){
    int mid=l+r>>1;
    f(ls(x),l,mid,tag1[x],tag2[x]);
    f(rs(x),mid+1,r,tag1[x],tag2[x]);
    tag1[x]=1,tag2[x]=0;
}
inline void update1(int x,int nowl,int nowr,int l,int r,int k){
    if(l<=nowl&&nowr<=r){
        (ans[x]*=k)%=mod;
        (tag1[x]*=k)%=mod;
        (tag2[x]*=k)%=mod;
        return;
    }
    push_down(x,nowl,nowr);
    int mid=nowl+nowr>>1;
    if(l<=mid)update1(ls(x),nowl,mid,l,r,k);
    if(mid<r)update1(rs(x),mid+1,nowr,l,r,k);
    push_up(x);
}
inline void update2(int x,int nowl,int nowr,int l,int r,int k){
    if(l<=nowl&&nowr<=r){
        (ans[x]+=k*(nowr-nowl+1)%mod)%=mod;
        (tag2[x]+=k)%=mod;
        return;
    } 
    push_down(x,nowl,nowr);
    int mid=nowl+nowr>>1;
    if(l<=mid)update2(ls(x),nowl,mid,l,r,k);
    if(mid<r)update2(rs(x),mid+1,nowr,l,r,k);
    push_up(x); 
}
inline int query(int x,int nowl,int nowr,int l,int r){
    if(l<=nowl&&nowr<=r)return ans[x];
    int res=0;
    push_down(x,nowl,nowr);
    int mid=nowl+nowr>>1;
    if(l<=mid)(res+=query(ls(x),nowl,mid,l,r))%=mod;
    if(mid<r)(res+=query(rs(x),mid+1,nowr,l,r))%=mod;
    return res%mod; 
}
signed main(){
    IOS;
    cin>>n>>m>>mod;
    for(int i=1;i<=n;i++)cin>>a[i],a[i]%=mod;
    build(1,1,n);
    for(int i=1;i<=m;i++){
        int op;
        cin>>op;
        if(op==1){
            int l,r,k;
            cin>>l>>r>>k;
            update1(1,1,n,l,r,k);
        }
        else if(op==2){
            int l,r,k;
            cin>>l>>r>>k;
            update2(1,1,n,l,r,k);
        }
        else{
            int l,r;
            cin>>l>>r;
            cout<<query(1,1,n,l,r)<<endl; 
        }
    }return 0;
}

最后,线段树并不难,只要你多写就能越来越熟练,并不需要背模板。