题解 P5889 【跳树】

· · 题解

P5889 跳树

首先大致读一下题,不难发现,每个点的二进制编码就是1号节点到这个节点的路径

最高的1代表根节点,之后的0代表左儿子,1代表右儿子,一直走到最低位,就是这个数对应的节点。

$10$ -> $00001010

1开始,后面是“010”,也就是“左右左”,画出树来,从1开始,左儿子——2,右儿子——5,左儿子——10

显然,由任意节点开始到它孩子的路径都可以按 “左01” 的方式表示。

所以,对于树上任意一个编号为s的节点,
操作1:s=s<<1
操作2:s=(s<<1)+1
操作3:s=s>>1

拓展一步,我们可以表示出
s的n辈祖先:s>>n
and
s的n代孩子:(s<<n)+ s到这个儿子 的路径
因为s表示 1s 的路径,我们把它左移n位(显然s到它的n代孩子要走n步),再加上s到这个孩子要走的路径,得到的也就是 1到孩子 的路径

我们考虑用线段树做这道题,那么要维护的就是
跳过这段区间后,s会如何改变

对每个区间定义如下:

struct node(){
    int fstl;   //s最高会跳到它的第几辈祖先上
    int l;      //跳到最高祖先后,s会跳到这个祖先的第几代孩子上(终点与最高祖先的深度差)
    ll num;     //祖先-终点 的路径
};

(max(1,s>>fstl)<<l)+num

表示跳过这个区间后s的位置(取max的意义是避免左移运算使代表根节点的1消失)

所以,我们可以把每个操作表示为:
1:fstl=1;
2:l=1;
3:l=1,num=1;

那么,剩下最大的敌人就是区间合并了。
s先跳过区间a,再跳过区间b的过程是这样的:

s>>=a.fstl; //由于这里s并不是真正的点,我们不把它与1取max
s<<=a.l;
s+=a.num;

s>>=b.fstl; //注意这一步,a.num被b.fstl“去”一部分
s<<=b.l;
s+=b.num;

由于代码中要多次用到区间合并,我们不妨直接重载运算符

node operator + (const node &b) const{

a.num在这里有一部分被b“抹去”了,所以我们可以根据它来分类讨论。

设区间a和区间b合并后的结果为区间ans

    node ans;

ab中有一个为空,ans=另一个

    if(!b.fstl&&!b.l){
        return *this;
    }
    if(!fstl&&!l){
        return b;
    }

ab非空,a.l>b.fstl,即a.numans产生了影响:

    if(l>b.fstl){
        ans.fstl=fstl;        //b.fstl全部小于a.l,也就是说不能比跳a区间时达到的最高位置更高
        ans.l=l-b.fstl+b.l;   //剩余的a.num长度加上b.num的长度
        ans.num=((num>>b.fstl)<<b.l)+b.num;
    }

ab非空,a.l<=b.fstl,即a.numb.fstl全部“抹去”,对ans无影响:

    else{
        ans.fstl=fstl+b.fstl-l;//原有的部分加上b.fstl的剩余部分
        ans.l=b.l;
        ans.num=b.num;
    }
    return ans;
}

然后套线段树模板即可。

Code

//码风清奇见谅
#include<iostream>
#include<cstdio>
using namespace std;
#define N 500007
typedef long long ll;
//线段树常规def
#define ls (rt<<1)
#define rs ((rt<<1)+1)
#define mid (l+r>>1)
#define lson rt<<1,l,l+r>>1
#define rson rt<<1|1,(l+r>>1)+1,r

ll read(){
    //put your 快读 here
}
template<typename Tp>
void print(Tp a){
    //put your 快写 here
}

struct node{//定义区间
    int fstl,l;
    ll num;
    node operator + (const node &b) const{//合并计算
        if(!b.fstl&&!b.l){
            return *this;
        }
        if(!fstl&&!l){
            return b;
        }
        node ans;
        if(l>b.fstl){
            ans.fstl=fstl;
            ans.l=l-b.fstl+b.l;
            ans.num=((num>>b.fstl)<<b.l)+b.num;
        }
        else{
            ans.fstl=fstl+b.fstl-l;
            ans.l=b.l;
            ans.num=b.num;
        }
        return ans;
    }
};

node tree[N*5];//线段树
ll n,m,sz,now;

void build(int rt,int l,int r){//建树
    if(l==r){
        int x=read();
        if(x==1){
            tree[rt].l=1;
        }
        else if(x==2){
            tree[rt].l=1;
            tree[rt].num=1;
        }
        else{
            tree[rt].fstl=1;
        }
        return;
    }
    build(lson);
    build(rson);
    tree[rt]=tree[ls]+tree[rs];
}
void upd(int rt,int l,int r,int pos,int x){//替换
    if(l==r){
        tree[rt]=(node){0,0,0};
        if(x==1){
            tree[rt].l=1;
        }
        else if(x==2){
            tree[rt].l=1;
            tree[rt].num=1;
        }
        else{
            tree[rt].fstl=1;
        }
        return;
    }
    if(pos>mid){
        upd(rson,pos,x);
    }
    else{
        upd(lson,pos,x);
    }
    tree[rt]=tree[ls]+tree[rs];
}
node query(int rt,int l,int r,int s,int t){//查询
    node ans=(node){0,0,0},lans,rans;
    if(l>t||r<s){
        return ans;
    }
    if(s<=l&&r<=t){
        return tree[rt];
    }
    lans=query(lson,s,t);
    rans=query(rson,s,t);
    return lans+rans;
}

int main(){
    sz=read(),n=read(),m=read();
    build(1,1,n);
    int opt,x,y;
    ll s;//这里用int第10个点会爆
    node r;
    while(m--){
        opt=read();
        if(opt&1){
            s=read(),x=read(),y=read();
            r=query(1,1,n,x,y);

            s=(max(1ll,s>>r.fstl)<<r.l)+r.num;
            print(s),putchar('\n');
        }
        else{
            x=read(),y=read();
            upd(1,1,n,x,y);
        }
    }
    return 0;
}

萌新第一篇题解,有问题请务必指出,谢谢!