题解:P13825 线段树 1.5

· · 题解

操作是非常简单的区间加和区间求和,但是问题就出在于这个序列的长度太大了。但我们发现那些没有修改过或者询问过的元素对答案没有任何的作用,所以我们考虑将修改区间和询问区间的端点离散化,然后我们就在离散化后的序列上建立线段树。

由于在离散化后的序列上,每一个元素不一定是相邻的,也就是说当前节点覆盖的区间长度不一定是原区间长度,所以要将线段树上的每个节点所覆盖区间的左端点和右端点记录下来的。

同理,两个区间合并的时候,它们之间同样也有可能有元素,因此我们要维护它们之间那段区间的长度以及区间和(下文称这个区间为中间区间)。

这个样子在区间修改和区间求和的时候,如果该节点的左儿子节点和右儿子节点我们都遍历过了,那么该节点所记录的中间区间也要修改,否则就不用修改。

因此,我们的线段树需要维护以下几个东西:

剩下就是线段树操作了。

注意要开 unsigned long long

代码部分:

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
const int N=1e5+10,M=2e5+10;
int n,m,ls,tot,num[M];//离散化后的序列长度得开两倍
struct opt{
    int type,l,r,x;
}a[N];
struct lsh{
    int x,y,z;
}s[M];//离散化 
struct tree{
    ull l,r,sum,len,midsum,plus;
}tr[M<<2];
bool cmp(lsh aa,lsh bb) {return aa.x<bb.x;}
int f(int x) {return x*(x+1)/2;}
void build(int k,int l,int r)
{
    if(l==r)
    {
        tr[k]=(tree){num[l],num[l],num[l],0,0,0};
        return;
    }
    int mid=(l+r)>>1,k1=k<<1,k2=k<<1|1;
    build(k1,l,mid);build(k2,mid+1,r);
    tr[k].l=tr[k1].l,tr[k].r=tr[k2].r;
    tr[k].len=(tr[k2].l-tr[k1].r-1);
    tr[k].midsum=f(tr[k2].l-1)-f(tr[k1].r);
    tr[k].sum=tr[k1].sum+tr[k2].sum+tr[k].midsum;
}
void pushdown(int k)
{
    if(!tr[k].plus) return;
    int k1=k<<1,k2=k<<1|1,pl=tr[k].plus;
    tr[k1].sum+=(tr[k1].r-tr[k1].l+1)*pl;
    tr[k1].midsum+=tr[k1].len*pl;
    tr[k1].plus+=pl;
    tr[k2].sum+=(tr[k2].r-tr[k2].l+1)*pl;
    tr[k2].midsum+=tr[k2].len*pl;
    tr[k2].plus+=pl;
    tr[k].plus=0;
}
void change(int k,int l,int r,int x,int y,int v)//区间加
{
    if(x<=l&&r<=y)
    {
        tr[k].sum+=(tr[k].r-tr[k].l+1)*v;
        tr[k].midsum+=tr[k].len*v;
        tr[k].plus+=v;
        return;
    }
    int mid=(l+r)>>1,k1=k<<1,k2=k<<1|1;
    bool lc=false,rc=false;
    pushdown(k);
    if(x<=mid) change(k1,l,mid,x,y,v),lc=true;
    if(y>mid) change(k2,mid+1,r,x,y,v),rc=true;
    if(lc&&rc) tr[k].midsum+=tr[k].len*v;
    tr[k].sum=tr[k1].sum+tr[k2].sum+tr[k].midsum;
}
ull query(int k,int l,int r,int x,int y)//区间求和
{
    if(x<=l&&r<=y) return tr[k].sum;
    int mid=(l+r)>>1,k1=k<<1,k2=k<<1|1;ull sum=0;
    bool lc=false,rc=false;
    pushdown(k);
    if(x<=mid) sum+=query(k1,l,mid,x,y),lc=true;
    if(y>mid) sum+=query(k2,mid+1,r,x,y),rc=true;
    if(lc&&rc) sum+=tr[k].midsum;
    return sum;
}
signed main()
{
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<=m;i++)
    {
        int type,l,r,x=0;
        scanf("%lld%lld%lld",&type,&l,&r);
        if(type==1) scanf("%lld",&x);
        a[i]=(opt){type,l,r,x};
        s[++ls]=(lsh){l,i,0},s[++ls]=(lsh){r,i,1};
    }
    sort(s+1,s+1+ls,cmp);
    for(int i=1;i<=ls;i++)
    {
        tot+=(s[i].x!=s[i-1].x);
        if(s[i].z) a[s[i].y].r=tot;
          else a[s[i].y].l=tot;
        num[tot]=s[i].x;
    }
    build(1,1,tot);
    for(int i=1;i<=m;i++)
    {
        int type=a[i].type,l=a[i].l,r=a[i].r,x=a[i].x;
        if(type==1) change(1,1,tot,l,r,x);
        if(type==2) cout<<query(1,1,tot,l,r)<<'\n';
    }
    return 0;
}