题解 P3369 【【模板】普通平衡树】
hicc0305
2018-08-09 20:53:59
## 非旋转treap了解一下
个人觉得比旋转treap好打很多,不用搞来搞去转来转去。。
非旋转treap主要通过两个操作:合并(merge),分离(divide)完成几乎这里的所有操作(除了查询排名)。
### Merge
首先我们要知道,我们不能随便合一下,我们要维护Treap的性质,也就是二叉查找树+堆。以一个随机的key值来维护堆,结点的权值val来维护二叉查找树。
我们已知棵子树的根l和r(l<r),我们比较l和r的key值,因为我们维护的是一个小根堆,如果l的key比较小,让l为合并的树的根结点,然后r往l的右子树递归。反之,r把l踢下去,r为根结点,然后l往r的左子树递归。比如说l做了根结点,r也找到了合适的位置,那么我们可以把r直接接上去了,然后把原来l上的被踢掉的部分取出来,往刚接上去的部分在重新找它的位置。。借来借去,踢来踢去,最后接好的结点越来越多。。然后就合完了。。
图:
![](https://cdn.luogu.com.cn/upload/pic/27744.png)
具体可以看代码:
```cpp
int merge(int l,int r)
{
if(!l) return r;//l没有了,全部合完了那么r接上即可,返回根r
if(!r) return l;//同理
if(tr[l].key<tr[r].key)//l的key小,l当当前子树的根
{
tr[l].ch[1]=merge(tr[l].ch[1],r);//往右子树递归
update(l);
return l;
}
else
{
tr[r].ch[0]=merge(l,tr[r].ch[0]);
update(r);
return r;
}
}
```
### Divide
就是把权值小于等于k的结点全部分在一个子树l中,大于的分在另一个子树r中。其实就是按照二分查找的方式向下找,当前结点小于k的话往右子树找,然后当前结点加入左子树。反一下就不讲了
恕我不一步一步画了:
![](https://cdn.luogu.com.cn/upload/pic/27752.png)
```cpp
void divid(int u,int x,int &l,int &r)//注意取地址符,传过来的参数会跟着更新
{
if(!u)
{
l=0,r=0;
return;
}
if(tr[u].val<=x) l=u,divid(tr[u].ch[1],x,tr[u].ch[1],r);//这里很巧妙,也就是让x的右子结点变成右子节点的右子树中,被分出来的左子树的根。好绕。。。
else r=u,divid(tr[u].ch[0],x,l,tr[u].ch[0]);//反过来一样的
update(u);
}
```
### Insert
以权值k为分界值,分出左子树l和右子树r,在把这三个东西重新合并起来,两两合并merge(merge(l,k),r),插入完成。这应该不用画图也能懂。
### Delete
以k为分界值,分出l,p,在l里,在以k-1为分界值分出,l,r,这样把k分离了出来(比k小的,比k大的,都分出去了),然后,合并k的左右子树,形成新的r,再把l,r,合并即可。
图:
![](https://cdn.luogu.com.cn/upload/pic/27765.png)
### 求k的排名
把整棵树以k-1为分界值分了,那么小于k的节点全在左树内,输出左树的size即可
### 找第k大
这不用多讲,二叉查找树基本操作,根据每个结点的size(当前子树的结点总数)递归即可
```cpp
int kth(int u,int k)
{
if(k<=tr[tr[u].ch[0]].siz) return kth(tr[u].ch[0],k);
else
{
if(k==tr[tr[u].ch[0]].siz+1) return u;
else
{
k-=tr[tr[u].ch[0]].siz+1;
return kth(tr[u].ch[1],k);
}
}
}
```
### 前驱
把整棵树以k-1为分界值分了,那么小于k的节点全在左树内,在左子树里找到排名为最后的结点权值即可,也就是排名为左子树的siz的点。当然还有合回来。。
### 后继
差不多的,自己想去
## 代码
不多解释了,结合上面应该能看懂
```cpp
#include<ctime>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,root,cnt=0;
struct Treeeeee
{
int ch[2],fa,val,key,siz;
}tr[100100];
void Add(int x)
{
cnt++;
tr[cnt].siz=1,tr[cnt].val=x,tr[cnt].key=rand();
tr[cnt].ch[0]=tr[cnt].ch[1]=0;
}
void update(int x)
{
tr[x].siz=tr[tr[x].ch[0]].siz+tr[tr[x].ch[1]].siz+1;
}
void divid(int u,int x,int &l,int &r)
{
if(!u)
{
l=0,r=0;
return;
}
if(tr[u].val<=x) l=u,divid(tr[u].ch[1],x,tr[u].ch[1],r);
else r=u,divid(tr[u].ch[0],x,l,tr[u].ch[0]);
update(u);
}
int merge(int l,int r)
{
if(!l) return r;
if(!r) return l;
if(tr[l].key<tr[r].key)
{
tr[l].ch[1]=merge(tr[l].ch[1],r);
update(l);
return l;
}
else
{
tr[r].ch[0]=merge(l,tr[r].ch[0]);
update(r);
return r;
}
}
int kth(int u,int k)
{
if(k<=tr[tr[u].ch[0]].siz) return kth(tr[u].ch[0],k);
else
{
if(k==tr[tr[u].ch[0]].siz+1) return u;
else
{
k-=tr[tr[u].ch[0]].siz+1;
return kth(tr[u].ch[1],k);
}
}
}
int main()
{
int l,r,p;
srand(time(NULL));
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int k,x;
scanf("%d%d",&k,&x);
if(k==1)
{
divid(root,x,l,r);
Add(x);
root=merge(merge(l,cnt),r);
}
else if(k==2)
{
divid(root,x,l,p);
divid(l,x-1,l,r);
r=merge(tr[r].ch[0],tr[r].ch[1]);
root=merge(merge(l,r),p);
}
else if(k==3)
{
divid(root,x-1,l,r);
printf("%d\n",tr[l].siz+1);
root=merge(l,r);
}
else if(k==4) printf("%d\n",tr[kth(root,x)].val);
else if(k==5)
{
divid(root,x-1,l,r);
printf("%d\n",tr[kth(l,tr[l].siz)].val);
root=merge(l,r);
}
else
{
divid(root,x,l,r);
printf("%d\n",tr[kth(r,1)].val);
root=merge(l,r);
}
}
return 0;
}
```