题解 P3391 【【模板】文艺平衡树】

· · 题解

普通平衡树实现的各种操作都是基于大小关系,回想一下,它之所以能实现那么多操作,最本质的原因是它巧妙地依据元素的大小关系建树,小的元素放在左子树,大的元素放在右子树,这样中序遍历的结果就恰恰是一个递增序列。而这题要实现的变成了对给定序列的若干个区间进行翻转操作,而翻转操作实际上改变的是前后关系。这就启发我们对普通平衡树做一些修改,改为依靠元素的前后关系建树,靠前的元素在左子树,靠后的元素在右子树,然后在进行区间翻转时维护这一性质,那么最后得到的中序遍历就是整个序列经过若干次区间翻转操作后的结果。

这样做的难点就在于如何在进行区间翻转时维护整棵树的性质不变。我们先从最特殊的情况想起,如果说整棵树的高度为 2,比如只有 3 个结点:根、左儿子、右儿子,那么翻转整个区间显然相当于把根结点的左儿子和右儿子互换。如果高度大于 2,我们只需要自上而下地把每一个结点的左右子树互换。证明如下:

设整棵树的深度为 n,已知当 n=2 时结论成立,求证:假设 n=k-1 时结论成立,则 n=k 时结论也成立。设该树的高度为 k,根结点为 b,左子树的中序遍历序列为 A,右子树的中序遍历序列为 C,那么互换前的中序遍历就是 AbC,互换后是 C'bA',由于 n=k-1 时结论成立,因此 C'C 翻转的结果,B'B 翻转的结果,故 C'bA'AbC 翻转的结果,结论得证。

因此对整个序列做翻转就相当于自上而下地把树上的每一个结点的左右子树互换。但如果是对区间 [l,r] 做翻转呢?我们该如何使这段区间恰好对应着一棵完整的子树?这里 Splay 就派上用场了,我们只需要将序列里的第 l-1 个元素旋转到根结点,把第 r+1 个元素旋转到根结点的右儿子处,由于我们是按照前后关系建树的,并且 [l,r] 这段区间的元素都在 a_{l-1} 的后面、a_{r+1} 的前面,所以在旋转结束后根结点的右儿子的左子树就恰好对应着 [l,r] 这段区间,这就是所谓的 Splay 提取区间(事实上普通平衡树的删除操作就是提取区间的一个特殊情况)。那翻转区间 [l,r] 就相当于对这棵子树自上而下地进行互换操作。

至于如何找到第 l-1 和第 r+1 个元素,这相当于查找第 k 前的元素,也就是说在中序遍历上排第 k 的元素,直接套用 splay 的查找第 k 小的函数即可 。还有一点要注意,就是当 l=1r=n 时第 l-1 或第 r+1 个元素会不存在,为了防爆,我们需要事先把正无穷和负无穷插入树中。

不过,如果你每次都把要进行互换操作的子树直接从头到尾互换一遍,那时间复杂度会爆炸。这里我们要借鉴一下线段树的 lazy tag 思想,给每个结点打一个懒标记表示以它为根的子树被互换了多少次。如果我用不到它的子结点,那它的懒标记就在那放着,等到我用到它的子结点时再把标记下放。同时,由于对一颗子树互换两次等于不互换,因此懒标记只需要记录 0/1 值。

代码如下(点个赞再走吧QAQ,谢谢您!):

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define ll long long
#define fo(i,x,y) for(register int i=x;i<=y;++i)
#define go(i,x,y) for(register int i=x;i>=y;--i)
using namespace std;
inline int read(){ int x=0,fh=1; char ch=getchar(); while(!isdigit(ch)){ if(ch=='-') fh=-1; ch=getchar(); } while(isdigit(ch)){ x=(x<<1)+(x<<3)+ch-'0'; ch=getchar(); } return x*fh; }

const int N=1e5+5,inf=1e9;
int data[N],ch[N][2],fa[N],siz[N],root,tag[N],tot,a[N],cnt[N];

inline int son(int &x){return x==ch[fa[x]][1];}
inline void push_up(int &x){siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];}

inline void push_down(int x){//把x的懒标记下放到它的两个儿子 
    if(tag[x]==0) return;
    int lt=ch[x][0],rt=ch[x][1];
    tag[lt]^=tag[x];
    swap(ch[lt][0],ch[lt][1]);
    tag[rt]^=tag[x];
    swap(ch[rt][0],ch[rt][1]);
    tag[x]=0;
}

inline void rotate(int x){//rotate函数 
    int y=fa[x],z=fa[y],sx=son(x),sy=son(y),b=ch[x][!sx];
    push_down(y);push_down(x);
    //注意,rotate函数会改变以y为根和以x为根的子树的结构,因此要在旋转之前先把它们的懒标记下放
    //其余的与splay模板一致 
    if(z) ch[z][sy]=x; else root=x;fa[x]=z;
    ch[x][!sx]=y;fa[y]=x;
    ch[y][sx]=b;if(b) fa[b]=y;
    push_up(y);push_up(x);
}

inline void splay(int x,int aim){//splay函数,与模板无异 
    while(fa[x]!=aim){
        int y=fa[x],z=fa[y];
        if(z==aim) rotate(x);
        else if(son(x)^son(y)) rotate(x),rotate(x);
        else rotate(y),rotate(x);
    }
}

inline int get_kth(int k){//查找第k前的函数,与模板无异 
    k--;
    int now=root;
    while(now){
        if(siz[ch[now][0]]>k) push_down(now),now=ch[now][0];
        else if(siz[ch[now][0]]+cnt[now]>k) return now;
        else k-=siz[ch[now][0]]+cnt[now],push_down(now),now=ch[now][1];
    }
}

inline void reverse(int L,int R){//翻转[L,R]这段区间 
    int p1=get_kth(L),p2=get_kth(R+2);
    //printf("reverse(%d,%d)\n",L,R);
    //printf("p1=%d p2=%d\n",p1,p2);
    splay(p1,0);splay(p2,root);//splay提取区间 
    int p3=ch[p2][0];
    tag[p3]^=1;//打标记 
    swap(ch[p3][0],ch[p3][1]);
}

int build(int L,int R){//根据初始的序列建树 
    if(L>R) return 0;
    //printf("build(%d,%d)\n",L,R);
    int mid=(L+R)>>1,now=a[mid];
    if(now==-inf||now==inf) now=++tot;
    data[now]=a[mid];
    cnt[now]=siz[now]=1;
    if(L==R) return now;
    int lt=build(L,mid-1);
    ch[now][0]=lt;fa[lt]=now;
    int rt=build(mid+1,R);
    ch[now][1]=rt;fa[rt]=now;
    push_up(now);
    return now;
}

void get_ans(int now){//输出中序遍历 
    push_down(now);
    if(ch[now][0]) get_ans(ch[now][0]);
    if(data[now]!=-inf&&data[now]!=inf) printf("%d ",data[now]);
    if(ch[now][1]) get_ans(ch[now][1]);
}

int main(){
    int n=read(),m=read();
    a[1]=-inf;a[n+2]=inf;//为了防爆,提取将负无穷和正无穷插入树中 
    fo(i,2,n+1) a[i]=i-1;
    tot=n;
    root=build(1,n+2);
    fo(i,1,m){
        //printf("root=%d\n",root);
        //fo(j,1,tot) printf("%d:data=%d,fa=%d,ch0=%d,ch1=%d,siz=%d,tag=%d\n",j,data[j],fa[j],ch[j][0],ch[j][1],siz[j],tag[j]);
        int l=read(),r=read();
        reverse(l,r);
    } 
    //printf("root=%d\n",root);
    //fo(j,1,tot) printf("%d:data=%d,fa=%d,ch0=%d,ch1=%d,siz=%d,tag=%d\n",j,data[j],fa[j],ch[j][0],ch[j][1],siz[j],tag[j]);
    get_ans(root);
    return 0;
}
/*
-------------------------------------------------
*/