题解 P3372 【【模板】线段树 1】

合451518

2017-08-04 11:04:48

题解

标准的线段树模板题——区间更新。

主程序——简单判断

每次读入第一个数,若是1更新,若是2直接输出。(不讲了)

关键问题——五个函数

up(int p)

这个函数主要用于更新当前节点p的值。

节点p的子节点为2p和2p+1.

s[p]=s[2p]+s[2p+1]

down(int p, int l, int r)

首先,储存节点p的值的s[p]指a[l]+...+a[r]的数字和。

这个函数用于传递add[]数组的量。

由于是区间更新,我们不能像单点更新一样一个一个更新,而是要利用add[]来存放更新的量。

这个函数1. 下传add[p]: add[2p]+=add[p];add[2p+1]+=add[p];

  1. 修改子节点2p,2p+1的值:s[2p]+=(m-l+1)add[p];s[2p+1]=(r-m)add[p];

其中m=(l+r) div 2;

这里注意一下:add[p]指l-r这个区间当中:任意add[i]都应该加add[p].其中l<=i<=r.

所以我们在更新节点时,还要用add[p]乘以它区间的长度,也就是s[2p]=(m-l+1)*add[p]了。

最后将add[p]=0;

build(int p, int l, int r)

p,l,r的意思同上。

从p开始建树。

当l==r时:到了叶子节点,读入其值。

注意:应在开始之前:add[p]=0(“洗碗”操作——归零)

然后定义m=(l+r)/2;分成两路更新,最后up(p);

update(int p, int l, int r, int x, int y, int v)

p,l,r意思同上,x,y,v表示[x,y]之间的每一个点都需要+v.

注意:add[p]不需要着急传下去,我们在查询query时也会下传。

  1. 当x<=l and r<=y 时,说明[l,r]区间 包含于[x,y]中,可以直接更新。

所以:add[p]+=v;s[p]+=(r-l+1)v;return;

  1. 正常更新。因为这是区间更新,我们要像线段树点更新中的查询一样,分成两路进行更新。

if x<=m then update(...) //更新2*p节点

if m<y then update(...) //更新2p+1节点

这样做是因为有可能出现x<=m<=y的情况。[x,y]被m分开,所以分子节点更新。

注意:在更新前应先下传add:down(p,l,r).

防止信息留在p的父节点p/2上。

最后更新up(p);

query(int p, int l, int r, int x, int y)

查询[x,y]区间和。

  1. if x<= l && r<=y return s[p];

如果[l,r]包含于[x,y],可以直接实用[l,r]区间的数字和,直接返回。

  1. 分成两路查询。

在查询之前,应再次down(p,l,r),理由同update.

然后定义m=(l+r)/2;

if x<=m ret+=query... //获得对于2*p及其子节点进行查询的值

if m<y ret+=query... //获得对2p+1及其子节点进行查询的值

最后返回ret.

注意:

  1. 数可能很大,数组定义为long long类型(pascal中的int64);

  2. 一定要在update和query中down(p,l,r),及时下传add[p]。

下面给完整代码:

//Luogu 3372
#include<iostream>
#include<cstdio>
using namespace std;

const int MAXN=100000;

int n, q;
long long add[4*MAXN+10], s[4*MAXN+10];

void up(int p)
{
    s[p]=s[2*p]+s[2*p+1];
    return;
}

void down(int p, int l, int r)
{
    if (add[p]) {
        add[2*p]+=add[p];
        add[2*p+1]+=add[p];
        int m=(l+r)/2;
        s[2*p]+=(m-l+1)*add[p];
        s[2*p+1]+=(r-m)*add[p];
        add[p]=0;
    }
    return;
}

void build(int p, int l, int r)
{
    add[p]=0;
    if (l==r) {
        scanf("%d",&s[p]);
        return;
    }

    int m=(l+r)/2;
    build(2*p, l, m);
    build(2*p+1, m+1, r);
    up(p);
    return;
}

void update(int p, int l, int r, int x, int y, int v)
{
    if (x<=l && r<=y) {
        add[p]+=v;
        s[p]+=(r-l+1)*v;
        return;
    }

    down(p, l, r);
    int m=(l+r)/2;
    if (x<=m) update(2*p, l, m, x, y, v);
    if (m<y) update(2*p+1, m+1, r, x, y, v);
    up(p);
    return;
}

long long query(int p, int l, int r, int x, int y)
{
    if (x<=l && r<=y) {
        return s[p];
    }

    down(p, l, r);
    int m=(l+r)/2;
    long long ret=0;
    if (x<=m) ret+=query(2*p, l, m, x, y);
    if (m<y) ret+=query(2*p+1, m+1, r, x, y);
    return ret;
}

int main()
{
    cin >> n >> q;
    build(1,1,n);
    for (int i=1; i<=q; i++) {
        int tmp, a, b, x;
        scanf("%d",&tmp);
        if (tmp==1) {
            scanf("%d%d%d",&a,&b,&x);
            update(1, 1, n, a, b, x);
        }
            else {
                scanf("%d%d",&a,&b);
                cout << query(1, 1, n, a, b) << endl;
            }
    }
    return 0;
}