题解 P6136 【【模板】普通平衡树(数据加强版)】

· · 题解

看到题目的时候我一阵窃喜,因为上个月我刚刚研究了vector水平衡树怎么扩展,奋战两天后得到这样的结果。(建议先看完此题解,再看此博客)

但是因为这道题的n和m都较小,所以里面很多优化是不需要的,所以用块状vector做此题拥有代码简练,通俗易懂且速度快,空间小的优点。(可能是理解起来最简单,空间最小的做法)

由于m比n大得多,所以我下面用m代替原题中的n+m,同时为了方便,用size表示某vector中元素个数。

首先你要知道,vector的insert和erase是O(size)的,但常数特别小(如果卡满还是比bitset慢一半,随机数据好一点)(所以大概可以写成是O(size/w)?),然后还要知道vector怎么做普通平衡树(原版)。

这里假设大家都知道了(不知道的也不会来做这个题吧),我们发现这里m=1e6,直接套用那个做法时间复杂度是O(m^2)的,虽然它的常数特别小,但也没有小到这个程度,妥妥的TLE。

然后我们尝试优化:我们发现在这个做法中,只有insert和erase是O(size) 的,而其他操作都是O(log)或O(1)的,极度不平衡。众所周知,我们有一个处理复杂度不平衡情况的做法是分块。所以我们可以尝试用分块优化这个暴力。

但是由于元素会实时改变所以没法直接分块,要用块状链表之类的东西,所以我把这个东西叫块状vector。

具体而言:把一个大vector拆成若干个小vector,要保证这些vector都是有顺序的,即下一个vector中的元素都不小于上一个vector中的元素,同时一个vector内部也要保证有序。

首先考虑重构操作,由于m较小,所以我们甚至不需要真的写块状链表,每进行t次操作后可以暴力全局重构,每次重构时间复杂度是O(m)的,所以这一部分是O(m^2/t)的,此外我们要决定重构块长,记为len,表示重构结束每个vector中有len个元素。

uodate:我已经找到了更好的写法方便的重构方法:不管t的限制,只有当某一个块被删成空块或长度大于2*len的时候再重构(更靠近块状链表的做法),这样有一个好处是数据弱的时候可以大幅度减少重构次数,最坏情况下也是每len次操作就要重构一次,所以不会造成负优化,由于数据较水,我把len开的很小也没有出现大量重构的现象,于是这个做法现在成了最优解,这里是提交记录。

然后就是平衡树部分的操作了。

首先要写一个work()函数,表示我们要找到某个元素在第几个块,具体实现可以从第一个vector开始每次比较这个vector开头的元素和我当前要查询的元素哪个更大,遇到第一个比它大的,就返回上一个vector。

插入:先调用work(),然后正常的二分insert就好了。

删除:先调用work(),然后正常的二分找到它的位置然后erase。

查排名:先找到它在哪一块,然后答案就是它前面所有的vector的size和加上当前vector中比它小的元素个数,最后+1.(注意这里排名指的是比它小的元素的个数+1,它自己不一定真的出现某个在vector中)

查第k小:从第一个vector开始,每次判断k与这个vetor的size大小,如果k>size,说明它不在这个vector内,将k减size,然后继续查找下一个vector,否则这个元素就在这个vector内,直接返回vector中第k个元素就好了。

求x的前驱:先找到x的排名k,然后返回第k-1小。

求x的后继:先找到x+1的排名k,然后返回第k小。

分析复杂度:由于每块长为len,所以共有m/len个块,所以我们的work()函数时间复杂度是 O(m/len)的,其他6个函数本质上都需要调用一次work()函数所以都要算上,此外查排名有一个O(log len) 的二分,前驱后继(为了好写)调用了排名函数所以也要带这个log,insert和erase除了带这个log还额外附加一个O((len+t)/w),其中w是个大常数,加t的原因是可能多次只向一个vector插入会使它产生无法忽略的变长。

所以最慢的操作依旧为insert,复杂度O(m/len + log len + (len+t)/w);

算上重构的时间,总的时间复杂度为O( m^2/t + m( m/len +log len +(len+t)/w ) ),这里取 len=t=1e4 即可通过本题。

update:由于今天发现vector的常数没有我想象的那么小,所以把len和t开的再小一点可能会更好。

但事实上此做法还有较高的拓展性,具体做法我已经放在最开头的链接的博客里了,用合理的方法优化此做法能在m=1e7的时候仍不输平衡树,而且可能也有很多我不知道的常数优化,但由于比较复杂,我个人也没有写过块状链表,所以我没有代码实现,那么欢迎大家实现一下那个做法。

按t重构的话还有一个细节:如果t>len,那么有可能会在两次重构期间删完某个vector,边界条件变得难以处理,一个实现简单的解决方法就是删到某 vector为空时直接重构。此外最后一个vector由于长度不足len,所以无论t和len怎么设,都有可能被删完,这时特判一下,如果最后一个vector被删完了直接把vector的数量-1就好了。

最后贴上简单易懂的代码,有一些实现细节看代码吧:(按t重构)

#include<bits/stdc++.h>
#define ll long long
#define find() lower_bound(v[i].begin(),v[i].end(),x)
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
using namespace std;
inline int read()
{
    static char buf[1<<21],*p1,*p2;
    int f=0,c=1; char ch=gc();
    while(ch<'0'||ch>'9') { if(ch=='-') c=-1; ch=gc(); }
    while(ch>='0'&&ch<='9') { f=f*10+ch-'0'; ch=gc(); }
    return c*f;
}
const int t=10000,len=10000;//每t次操作重构一次,每个vector的长度为len 
int a[1234567];
vector <int> v[1000];
int op,x,y,n,m,sz,tot,cnt,lst,ans;//sz记有几个vector 
void init(int flag=0)//重构 
{
    if(!flag) tot=0; cnt=0; 
    for(int i=1;i<=sz;++i)
    {
        int m=v[i].size();
        for(int j=0;j<m;++j)
            a[++tot]=v[i][j];
        v[i].clear();
    }
    sz=tot/len+(tot%len!=0);
    sz=max(sz,1);//特判一下没有元素的情况 
    for(int i=1;i<sz;++i)
        for(int j=0;j<len;++j) 
            v[i].push_back(a[++cnt]);
    while(cnt<tot) v[sz].push_back(a[++cnt]);
}
int work(int x)//定位x在那一块 
{
    int i=1;
    while(i<=sz&&v[i][0]<x) ++i;
    --i;  
    return max(1,i);//有可能x是最小的元素会返回0,特判一下 
}
void insert(int x) { int i=work(x); v[i].insert(find(),x); }
void erase(int x) 
{ 
    int i=work(x); 
    if(find()!=v[i].end()) v[i].erase(find()); //有可能它就是下一个块开头的元素,特判一下
    else v[i+1].erase(v[i+1].begin());
    if(v[sz].size()==0) sz=max(sz-1,1);//最后一块由于不完整,是有可能被删完的 
}
int rnk(int x) 
{ 
    int i=work(x),ans=0;
    for(int j=1;j<i;++j) ans+=v[j].size();
    return ans+(find()-v[i].begin())+1;
}
int kth(int x)
{
    for(int i=1;i<=sz;++i)
    {
        if(x<=v[i].size()) return v[i][x-1]; 
        x-=v[i].size();
    }
}
int pre(int x) { return kth(rnk(x)-1); }
int nxt(int x) { return kth(rnk(x+1)); }
int main()
{
//  freopen("test.in","r",stdin);
//  freopen("test.out","w",stdout);
    n=read(); m=read();
    for(int i=1;i<=n;++i) a[i]=read();
    sort(a+1,a+n+1);
    tot=n; init(1);//开始的时候由于n>len,  所以这里要重构 
    while(m--)
    {
        if(m%t==0) init();//重构 
        op=read(); x=read()^lst;
        if(op==1) insert(x);
        if(op==2) erase(x);
        if(op==3) lst=rnk(x);
        if(op==4) lst=kth(x);
        if(op==5) lst=pre(x);
        if(op==6) lst=nxt(x);
        if(op>=3) ans^=lst;
    }
    cout<<ans;
    return 0;
}