带插区间k小值题解

· · 题解

静态区间k小值

首先我们先考虑静态维护区间 k 小值。

容易想到二维分块:对值域与下标都进行分块

考虑如果所有的 lr 这个区间刚好是完整的块的时候,该怎么维护?

显然我们可以先维护一个数组 a_{ij} 表示在第i个下标块、第j个值域块的数的个数。再对每一个值域块维护前缀和 sum_{ij}=\sum_{k=1}^i a_{kj}

同样的对于块内的散点也维护一个数组 b_{i,j} 表示第i块内值刚好是j的数的个数,也个前缀和数组。然后就像权值分块那样扫过去,即可 T+\frac{n}{T} 求静态区间 k 小值。

再考虑,在上面的基础上,额外给定一个数组(元素个数 T_2),求这个数组与一个刚好是整块的区间合并起来之后的第 k 小值,怎么做?

我们可以考虑把这 T_2 个数全部塞到最右边的下标块里,或者放入新建的一个额外块中,在求完 k 小值后再取出来。因为我们求 k 小值不需要用到比最右边一块还要右侧的 sum 数组,所以不用去维护前缀和,只要修改最右边一块的 sum 值就行。

因此,当 lr 不是整块的时候,我们就可以将两端的散点当成给定的数组丢进去就行了。

假设下标块长为 T_2,值域块长为 T_1

所以查询一次静态区间k小值复杂度是 O(T_2+T_1+\frac{n}{T_1}),预处理复杂度为 O(\frac{N}{T_2} \times N)

带插区间k小值

假设原序列为 A,则我们在任意时刻,任取一段区间,不考虑下标顺序,这段区间内的数集相当于是序列 A 中取出一段,删除一些数,增加一些数得到的数集。

修改操作相当于删一个插一个,插入操作不用多说。

如果删除的数和增加的数的总数为 T_3,我们就可以通过上半部分提到的方法,O(T_3+T_2+T_1+\frac{n}{T_1}),通过维护原序列 A 上的区间与增加的数合起来的区间第k小值来维护修改后序列的区间第k小值了。

我们可以维护一个链表,链表上存放一些四元组 (l,r,tag1,tag2)

对于 tag1=-1 的情况,这个四元组表示原数组的 lr 的这段区间,tag2 可被忽略。

对于 tag\not=-1 的情况,这个四元组表示一个值为 tag1 的数,lr 可被忽略。此时,若 tag2=0 表示这是插入的数;若 tag2=1 表示这是对一个插入的数做修改操作得到的数;若 tag2=3 表示这是原数列上的某个数被做修改得到的数。

我们设每次查询的时候,得到的一段区间是从原序列 A 的某一个区间 S 经过修改得到的。求 S 区间的时候可以直接暴力在链表里扫一遍。

对于所有 tag1!=-1 的情况相当于在S上加入了某个数,其中对于 tag2=3 的情况相当于还要额外在 S 中去除某个数。

对于插入或者修改操作,暴力在链表上找,如果要修改或者插入的位置刚好在某一个四元组代表的区间内,就通过修改l与r的值将这个四元组分成两个四元组,然后对于修改操作再注意一下 tag2 的标记。

所以我们便可以用 O(T_2+T_1+\frac{n}{T_1}+T_3) 的复杂度完成询问操作,O(T_3) 完成修改操作。

但是这样显然当链表内节点过多的时候会出现 T_3=N 的情况。

所以我们可以做一个优化:

当链表内的节点个数超过某个值的时候,我们就需要把链表里的数取出来变成一个新的序列,把这个序列当作 A 数组,然后重新预处理出静态区间第k小。

依旧设这个值为 T_3

每次修改/插入操作都会导致多出 2 个数。忽略常数并且最劣化考虑,假设全是修改没有查询,就是每经过 T_3 次操作就要重新得到序列并且预处理。

也就是说多次重构的总复杂度是 O(\frac{M}{T_3}\times \frac{N}{T_2} \times N)

再考虑单次操作的复杂度,假设全是询问操作,则询问总复杂度就是 O(M\times(T_2+T_1+\frac{N}{T_1}+T_3) )

所以我们取 T_3 =T_2=T_1=N^{\frac{2}{3}} 就可以得到最优复杂度 O(N^{\frac{5}{3}})

稍微优化下常数即可通过。

代码

#include<stdio.h>
#include<iostream>
#include<math.h>
#include<cstring>
using namespace std;
class block{
    public:
    int l,r,nt,pr,tag,tag2;
    int sz(){
    return r-l+1;
    }
}a[700003];
int cnt=0,head;
int split(int x,int k){
    cnt++;
    a[cnt].r=a[x].r;
    a[cnt].l=a[x].l+k;
    a[x].r=a[cnt].l-1;
    if(a[x].nt!=0) a[a[x].nt].pr=cnt;
    a[cnt].pr=x;
    a[cnt].nt=a[x].nt;
    a[x].nt=cnt;
    a[cnt].tag=a[x].tag;
    a[cnt].tag2=a[x].tag2;
    return cnt;
}
void insert(int r,int val){
    if(r==0) a[++cnt].nt=head,a[head].pr=cnt,head=cnt;
    else{
    int q=head,sum=0;
    while(sum+a[q].sz()<r){
    sum+=a[q].sz();
    q=a[q].nt;
    }
    if(r-sum!=a[q].sz()) q=a[split(q,r-sum)].pr;
    cnt++;
    if(a[q].nt) a[a[q].nt].pr=cnt;
    a[cnt].nt=a[q].nt;
    a[cnt].pr=q;
    a[q].nt=cnt;
    }
    a[cnt].l=1,a[cnt].r=1,a[cnt].tag=val,a[cnt].tag2=0;
}
void change(int r,int val){
    int sum=0,p=head;
    while(sum+a[p].sz()<r) sum+=a[p].sz(),p=a[p].nt;
    if(r-1!=sum) p=split(p,r-sum-1);
    if(a[p].sz()!=1) split(p,1);
    a[p].l=a[p].r=1;
    a[p].tag=val;
    a[p].tag2|=1;
}
int d[80003],c[80003];
char op[230003];
int x[230003],y[230003],z[230003],M,Q,I;
class static_kth_block{
public:
int sum1[907][907],sum2[907][70008],sum3[70008],sum4[907];
int T=0,T2=400;
bool vis[70008];
void change(int val,int op){
    sum3[val]+=op,sum4[val/T2+1]+=op;
}
int N=0;
int static_kth(int l,int r,int k){
    int ll=(l-1)/T+2,rr=(r-1)/T,sum=0,p=1;
    for(int i=l;i<=(ll-1)*T;i++) if(i<=N)sum3[c[i]]++,sum4[(c[i])/T2+1]++;
    for(int i=rr*T+1;i<=r;i++) if(i<=N) sum3[c[i]]++,sum4[(c[i])/T2+1]++;
    while(sum+sum1[rr][p]-sum1[ll-1][p]+sum4[p]<k) {
    sum+=sum1[rr][p]-sum1[ll-1][p]+sum4[p],p++;
    }
    p=(p-1)*T2;
    while(sum+sum2[rr][p]-sum2[ll-1][p]+sum3[p]<k) sum+=sum2[rr][p]-sum2[ll-1][p]+sum3[p],p++;
    for(int i=l;i<=(ll-1)*T;i++) if(i<=N) sum3[c[i]]--,sum4[c[i]/T2+1]--;
    for(int i=rr*T+1;i<=r;i++) if(i<=N) sum3[c[i]]--,sum4[c[i]/T2+1]--;
    return p;
}
int tot=0,val[70003];
void init(int n){
    if(N)
    for(int i=1;i<=(N-1)/T+1;i++){
    for(int j=1;j<=tot;j++)
    sum2[i][val[j]]=0;  
    for(int j=1;j<=(70000/T2+1);j++)
    sum1[i][j]=0;
    }
    N=n;
    T=pow(n,0.667);
    if(T<=1) T=2;
    memset(vis,0,sizeof(vis));
    tot=0;
    for(int i=1;i<=n;i++){
    sum1[(i-1)/T+1][c[i]/T2+1]++;
    sum2[(i-1)/T+1][c[i]]++;    
    if(!vis[c[i]]) vis[c[i]]=1,val[++tot]=c[i]; 
    }
    for(int i=1;i<=(n-1)/T+1;i++){
    for(int j=1;j<=tot;j++)
    sum2[i][val[j]]+=sum2[i-1][val[j]]; 
    for(int j=1;j<=(70000/T2+1);j++)
    sum1[i][j]+=sum1[i-1][j];   
    }
}
} skb; 
/* 
*/
int n; 
#define kth(a,b,c) skb.static_kth(a,b,c) 
#define ins(x) skb.change(x,1)
#define del(x) skb.change(x,-1)
int top=0,top2=0,de[70003],in[70003];
int ask(int l,int r,int k){ 
    int ans=0;
    int sum=0,p=head;
    while(sum+a[p].sz()<l) sum+=a[p].sz(),p=a[p].nt;
    int ll=70009,rr=-1;
    if(a[p].tag==-1) ll=a[p].l+l-sum-1;
    int q=p;
    while(sum+a[q].sz()<r){
    if(a[q].tag==-1) {if(ll==70009) ll=a[q].l;rr=a[q].r;}
    else{
    ins(a[q].tag);
    in[++top2]=a[q].tag;
    if(a[q].tag2==3 && rr!=-1) del(c[++rr]),de[++top]=c[rr];
    }
    sum+=a[q].sz();
    q=a[q].nt; 
    }
    if(a[q].tag==-1) {if(ll==70009) ll=a[q].l;rr=a[q].l+r-sum-1;}
    else{
    ins(a[q].tag);
    in[++top2]=a[q].tag;
    if(a[q].tag2==3 && rr!=-1) del(c[++rr]),de[++top]=c[rr];
    }
    if(ll>rr) {
    ll=1,rr=1;
    del(c[1]);
    de[++top]=c[1];
    }
    ans=kth(ll,rr,k);
    for(int i=1;i<=top;i++) ins(de[i]);
    top=0;
    for(int i=1;i<=top2;i++) del(in[i]);
    top2=0;
    return ans;
}
int tot=0;
void merge(){
    tot=0;
    for(int i=head;i;i=a[i].nt)
    for(int j=a[i].l;j<=a[i].r;j++) d[++tot]=(a[i].tag==-1)?c[j]:a[i].tag;
    swap(c,d);
    skb.init(tot);
    cnt=0;
    head=1;
    a[++cnt]={1,tot,0,0,-1,2};
}
int T,m;
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    scanf("%d",&c[i]);      
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
    op[i]=getchar();
    while(op[i]!='I' && op[i]!='M' && op[i]!='Q') op[i]=getchar();
    if(op[i]=='I'){I++;
    scanf("%d%d",&x[i],&y[i]);
    }
    if(op[i]=='M'){M++;
    scanf("%d%d",&x[i],&y[i]);
    }
    if(op[i]=='Q'){Q++;
    scanf("%d%d%d",&x[i],&y[i],&z[i]);
    }
    }
    skb.init(n);
    cnt=0;
    head=1; 
    a[++cnt]={1,n,0,0,-1,2};    
    int temp=0;
    T=pow(n+m,0.667);
    if(T<=1) T=2; 
    for(int i=1;i<=m;i++){
    if(op[i]=='I'){
    x[i]^=temp;
    y[i]^=temp;
    insert(x[i]-1,y[i]);
    }
    if(op[i]=='M'){
    x[i]^=temp;
    y[i]^=temp;
    change(x[i],y[i]);
    }
    if(op[i]=='Q'){
    x[i]^=temp;
    y[i]^=temp;
    z[i]^=temp;
    temp=ask(x[i],y[i],z[i]);
    printf("%d\n",temp);
    }
    if(cnt>=T)
    merge();
    }
}