题解 P5494 【模板】线段树分裂

· · 题解

模板题,要求维护权值线段树合并,分裂,单项插入,区间查询,找第 k 小。

权值线段树的节点 [l,r] 维护的信息有,他的左孩子和右孩子,区间内元素个数。

单项插入

权值线段树基本操作,此题不用离散化。

void update(int const &p,int const &v,int &x,int const &l=1,int const &r=n){
    if(!x) x=newnode();//动态开点
    tr[x].v+=v;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(p<=mid) update(p,v,tr[x].ls,l,mid);
    else update(p,v,tr[x].rs,mid+1,r);
}

区间查询

也是基本操作,注意的是动态开点线段树在找到空节点后就可以返回了。

long long query(int const &pl,int const &pr,int const &x,int const &l=1,int const &r=n){
    if(!x) return 0;
    if(pl==l&&pr==r) return tr[x].v;
    int mid=(l+r)>>1;
    if(pr<=mid) return query(pl,pr,tr[x].ls,l,mid);
    else if(pl>mid) return query(pl,pr,tr[x].rs,mid+1,r);
    else return query(pl,mid,tr[x].ls,l,mid)+query(mid+1,pr,tr[x].rs,mid+1,r);
}

找第 k

在线段树上二分,如果左孩子的元素个数大于等于 k,说明第 k 小在左子树内;否则,在右子树内。

int kth(long long const &p,int const &x,int const &l=1,int const &r=n){
    if(tr[x].v<p) return -1;
    if(l==r) return l;
    int mid=(l+r)>>1;
    if(tr[tr[x].ls].v>=p) return kth(p,tr[x].ls,l,mid);
    else return kth(p-tr[tr[x].ls].v,tr[x].rs,mid+1,r);
}

线段树合并

线段树合并有两种写法,这里写的是将一棵树合并到另一颗上的写法。对于 t 中没有的节点不用遍历;对于 p 中没有但 t 中有的,可以直接把该节点挂在 p 上。(p,t 的含义见题目)

void merge(int &x1,int &x2,int const &l=1,int const &r=n){
    if(!x2) return;
    if(!x1){x1=x2;x2=0;return;}
    tr[x1].v+=tr[x2].v;
    int mid=(l+r)>>1;
    merge(tr[x1].ls,tr[x2].ls,l,mid);
    merge(tr[x1].rs,tr[x2].rs,l,mid);
    delnode(x2);//垃圾回收
}

线段树分裂

终于到正题了,线段树分裂是线段树合并的逆操作。对于待分裂区域与其他区域的公有节点,复制一份;对于独有节点,直接拿过来挂上去;最后记得 pushup

void split(int const &pl,int const &pr,int &x1,int &x2,int const &l=1,int const &r=n){
    if(pl==l&&pr==r){
        x2=x1;
        x1=0;
        return;
    }
    if(!x2) x2=newnode();
    int mid=(l+r)>>1;
    if(pr<=mid) split(pl,pr,tr[x1].ls,tr[x2].ls,l,mid);
    else if(pl>mid) split(pl,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
    else split(pl,mid,tr[x1].ls,tr[x2].ls,l,mid),split(mid+1,pr,tr[x1].rs,tr[x2].rs,mid+1,r);
    tr[x1].v=tr[tr[x1].ls].v+tr[tr[x1].rs].v;//pushup
    tr[x2].v=tr[tr[x2].ls].v+tr[tr[x2].rs].v;
}

可回收垃圾(

新建节点和删除节点,开个垃圾桶,回收废弃节点。

int newnode(){
    if(tp!=lj)return *--tp;
    else return ++cnt;
}
void delnode(int &x){
    *tp++=x;
    tr[x].v=tr[x].ls=tr[x].rs=0;
    x=0;
}

时间复杂度证明

完整代码

link