P3373 【模板】线段树 2:一个带有乘和加的模板 - 题解

· · 题解

正文前提示

题目传送门

主要思路

其实,就是在线段树 1 的基础上加上了一个乘法。

具体做法

Step \mathbf1:主函数逻辑

读入 n 个数,建树,接下来 q 次操作,是 1 就将区间做乘法,是 2 就将区间做加法,是 3 就输出区间和。

主要代码:

    LL n,q;
    scanf("%lld%lld%lld",&n,&q,&m);
    for(int i=1;i<=n;i++){
        scanf("%lld",&w[i]);
    }
    build(1,1,n);//建树
    while(q--){
        int op;
        scanf("%d",&op);
        if(op==1){
            LL x,y,k;
            scanf("%lld%lld%lld",&x,&y,&k);
            update(1,x,y,k,0);//从节点 1 开始,范围是 x ~ y,给每一个数乘 k,给每一个数加 0。
        }else if(op==2){
            LL x,y,k;
            scanf("%lld%lld%lld",&x,&y,&k);
            update(1,x,y,1,k);//从节点 1 开始,范围是 x ~ y,给每一个数乘 1,给每一个数加 k。
        }else{
            LL x,y;
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",query(1,x,y));//从节点 1 开始,范围是 x ~ y,查询区间和。
        }
    }

Step \mathbf2:建树

先声明这个节点的范围是 l\sim r 的;其次,把这个节点的 sum 初值置为 w_lw_r(你爱写哪个写哪个);接着,把乘的懒标记赋好 1 的初值,加的懒标记赋好 0 的初值;然后,判断当前是否为叶子结点(l 是不是等于 r),如果是,返回,否则递归建立左右子树直到当前节点变成叶子结点为止;最后,把左右ㄦ子的 sum 累加到父亲那里去。

主要代码:

void pushup(LL p){
    tr[p].sum=(tr[lc].sum+tr[rc].sum)%m;//把左右ㄦ子的和都累加过来。
}
void build(LL p,LL l,LL r){
    tr[p]={l,r,w[r],1,0};
    if(l==r)return;
    LL mid=l+r>>1;
    build(lc,l,mid);
    build(rc,mid+1,r);
    pushup(p);
}

Step \mathbf3update(更新函数)逻辑

首先,如果越界了,赶快返回;其次,如果当前节点范围完全覆盖需要更新的部分,使用 calc 函数更新线段树后立即返回;接着,下传懒标记;然后,更新左右ㄦ子;最后,把左右ㄦ子的 sum 累加到父亲那里去。特别强调!!!更新的范围和节点的范围是不一样的!更新的范围始终保持不变!!!

这里讲一下 calc 函数怎么写。首先,这个函数是用来维护区间和和两个懒标记的。注意:需要先乘后加才能保证精度不丢失,因为先加后乘的话,新的 t.add 就会是 t.add+\dfrac{add}{mul},如果 \dfrac{add}{mul} 不是整数,那么精度将丢失。子节点新的 t.mult.mul \times mul,新的 t.addt.add\times mul+add。这里还得计算 t.sum。我们推出一个公式来计算它:\left(x_l\times mul+add\right)+\cdots+\left(x_r\times mul+add\right),化简后就是:

\begin{align} & \left(x_l\times mul+add\right)+\cdots+\left(x_r\times mul+add\right)\\ = & \left(x_l+\cdots+x_r\right)\times mul+\left(r-l+1\right) \times add\\ = & t.sum\times mul+\left(r-l+1\right) \times add \end{align}

主要代码:

void calc(node &t/*一定要传引用,不然无法修改 tr 数组*/,LL mul,LL add){
    t.sum=(t.sum*mul+(t.r-t.l+1)*add)%m;
    t.mul=t.mul*mul%m;      //⎫
//                             先乘后加,否则精度丢失
//                            ⎬新的 mul 为 mul × m,新的 add 为 add × m + a。
    t.add=(t.add*mul+add)%m;//⎭
}
void pushdown(LL p){//下传懒标记
    calc(tr[lc],tr[p].mul,tr[p].add);
    calc(tr[rc],tr[p].mul,tr[p].add);
    tr[p].mul=1;//⎫
//                 清空懒标记
//                ⎬
    tr[p].add=0;//⎭
}
void update(LL p,LL l,LL r,LL mul,LL add){
    if(tr[p].r<l||tr[p].l>r)return;//越界
    if(l<=tr[p].l&&tr[p].r<=r){//完全覆盖
        calc(tr[p],mul,add);
        return;
    }
    pushdown(p);
    update(lc,l,r,mul,add);
    update(rc,l,r,mul,add);
    pushup(p);
}

Step \mathbf4query 查询函数

首先,如果越界了,赶快返回;接着,如果当前节点范围完全覆盖需要更新的部分,返回当前节点的 sum 值;然后,下传懒标记;最后,返回左ㄦ子的查询结果和右ㄦ子的查询结果的和。

LL query(LL p,LL l,LL r){
    if(tr[p].r<l||tr[p].l>r)return 0;//越界
    if(l<=tr[p].l&&tr[p].r<=r){//完全覆盖
        return tr[p].sum%m;
    }
    pushdown(p);
    return (query(lc,l,r)+query(rc,l,r))%m;
}

AC 代码:

#include<iostream>
#include<stdio.h>
#define LL long long
#define lc p<<1   //左ㄦ子
#define rc p<<1|1 //右ㄦ子
using namespace std;
LL m,w[100005];
struct node{
    LL l,r,sum,mul,add;
}tr[400005];//大小是 w 数组的 4 倍。
void pushup(LL p){
    tr[p].sum=(tr[lc].sum+tr[rc].sum)%m;//把左右ㄦ子的和都累加过来。
}
void calc(node &t/*一定要传引用,不然无法修改 tr 数组*/,LL mul,LL add){
    t.sum=(t.sum*mul+(t.r-t.l+1)*add)%m;
    t.mul=t.mul*mul%m;      //⎫
//                             先乘后加,否则精度丢失
//                            ⎬新的 mul 为 mul × m,新的 add 为 add × m + a。
    t.add=(t.add*mul+add)%m;//⎭
}
void pushdown(LL p){//下传懒标记
    calc(tr[lc],tr[p].mul,tr[p].add);
    calc(tr[rc],tr[p].mul,tr[p].add);
    tr[p].mul=1;//⎫
//                 清空懒标记
//                ⎬
    tr[p].add=0;//⎭
}
void build(LL p,LL l,LL r){
    tr[p]={l,r,w[r],1,0};
    if(l==r)return;
    LL mid=l+r>>1;
    build(lc,l,mid);
    build(rc,mid+1,r);
    pushup(p);
}
void update(LL p,LL l,LL r,LL mul,LL add){
    if(tr[p].r<l||tr[p].l>r)return;//越界
    if(l<=tr[p].l&&tr[p].r<=r){//完全覆盖
        calc(tr[p],mul,add);
        return;
    }
    pushdown(p);
    update(lc,l,r,mul,add);
    update(rc,l,r,mul,add);
    pushup(p);
}
LL query(LL p,LL l,LL r){
    if(tr[p].r<l||tr[p].l>r)return 0;//越界
    if(l<=tr[p].l&&tr[p].r<=r){//完全覆盖
        return tr[p].sum%m;
    }
    pushdown(p);
    return (query(lc,l,r)+query(rc,l,r))%m;
}
int main(){
    LL n,q;
    scanf("%lld%lld%lld",&n,&q,&m);
    for(int i=1;i<=n;i++){
        scanf("%lld",&w[i]);
    }
    build(1,1,n);//建树
    while(q--){
        int op;
        scanf("%d",&op);
        if(op==1){
            LL x,y,k;
            scanf("%lld%lld%lld",&x,&y,&k);
            update(1,x,y,k,0);//从节点 1 开始,范围是 x ~ y,给每一个数乘 k,给每一个数加 0。
        }else if(op==2){
            LL x,y,k;
            scanf("%lld%lld%lld",&x,&y,&k);
            update(1,x,y,1,k);//从节点 1 开始,范围是 x ~ y,给每一个数乘 1,给每一个数加 k。
        }else{
            LL x,y;
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",query(1,x,y));//从节点 1 开始,范围是 x ~ y,查询区间和。
        }
    }
    return 0;
}

正文后闲话

其实,这道题你想用暴力枚举过这道题也没问题,注意超时就行。