题解:P13780 「o.OI R2」愿天堂没有分块

· · 题解

本题还是相对简单,但是需要对一些知识有一定积累。

看到区间 \text{MEX}\text{MEX},看似是一个十分复杂的问题。但是我们知道,拥有极小 \text{MEX} 的区间仅有 O(n) 个,这就让问题重新可做起来。

什么是拥有极小 \text{MEX} 的区间呢?就是这个区间不存在任何子区间(不包含自身),使得子区间的 \text{MEX} 和他自身的 \text{MEX} 相等。

防止有的同学不知道,这里给出一份简略的证明:

求出这些区间是简单的,拿 a_l<a_r 的情况举例,就是对于一个右端点 r,找到一个尽可能大的左端点 l,使得 \text{MEX}(l,r)>a_r 就可以了。是容易使用权值线段树在 O(n \log n) 的时间内找到的。

现在才算正式开始解决本题。我们可以采取从小到大枚举 \text{MEX}=i 的方式,然后找到哪些区间满足了 \text{MEX}=i

考虑所有满足 \text{MEX}=i 的极小 \text{MEX} 区间,这些区间满足两两没有包含关系。按照左端点排序后取出相邻的两个区间 [l_1,r_1],[l_2,r_2] ,如果一个询问满足了 l_1<l \leq r<r_2,那么显然不会包含任何一个 \text{MEX}=i 的子区间,那么就算求出它的答案了。因为答案已经被求出来,这个询问后续不考虑了。

我们会调用“找到在 [l,r] 的子区间内的询问”一共 O(n) 次,每一次找到子区间内的询问可以使用线段树简易维护做到 O(n \log n)

总体时间复杂度为 O(n \log n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int buf[1<<10];
inline void print(int x,char ch=' '){
    if(x<0) putchar('-'),x=-x;
    int tot=0;
    do{
        buf[++tot]=x%10;
        x/=10;
    }while(x);
    for(int i=tot;i;i--)
        putchar(buf[i]+'0');
    putchar(ch);
}
const int MAXN=1e6+5;
int n,m;
int a[MAXN];

int t[MAXN<<2];
void update(int i,int l,int r,int k,int w){
    if(l==r){
        t[i]=w;
        return ;
    }
    int mid=(l+r-1)>>1;
    if(mid>=k)
        update(i<<1,l,mid,k,w);
    else
        update(i<<1|1,mid+1,r,k,w);
    t[i]=min(t[i<<1],t[i<<1|1]);
}
int query(int i,int l,int r,int L,int R){
    if(L<=l&&r<=R)
        return t[i];
    int mid=(l+r-1)>>1;
    if(R<=mid)
        return query(i<<1,l,mid,L,R);
    if(mid<L)
        return query(i<<1|1,mid+1,r,L,R);
    return min(query(i<<1,l,mid,L,mid),query(i<<1|1,mid+1,r,mid+1,R));
}
int query(int i,int l,int r,int k){
    if(l==r){
        if(t[i]>=k) return l;
        return -1;
    }
    int mid=(l+r-1)>>1;
    if(t[i<<1]<k)
        return query(i<<1,l,mid,k);
    return max(mid,query(i<<1|1,mid+1,r,k));
}

int tot;
struct node{
    int l,r,val;
    node(int a=0,int b=0,int c=0){
        l=a,r=b,val=c;
    }
    bool friend operator<(const node &a,const node &b){
        return a.l<b.l;
    }
}b[MAXN*3];

void solve1(){
    update(1,0,n,0,n);
    for(int i=1;i<=n;i++){
        update(1,0,n,a[i],i);
        int x=query(1,0,n,0,a[i]);
        if(x!=0) b[++tot]=node(x,i,query(1,0,n,x)+1);
        b[++tot]=node(i,i,a[i]==1?2:1); 
    }
    reverse(a+1,a+n+1);
    for(int i=1;i<=n;i++) update(1,0,n,i,0);
    for(int i=1;i<=n;i++){
        update(1,0,n,a[i],i);
        int x=query(1,0,n,0,a[i]);
        if(x!=0) b[++tot]=node(n-i+1,n-x+1,query(1,0,n,x)+1);
    }
}

node c[MAXN];
int ans[MAXN];
int mn[MAXN<<2];
void build(int i,int l,int r){
    if(l==r){
        mn[i]=c[l].r;
        return ;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    mn[i]=min(mn[i<<1],mn[i<<1|1]);
}
void find(int i,int l,int r,int L,int R,int lim,int val){
    if(mn[i]>lim) return ;
    if(l==r){
        ans[c[l].val]=val;
        mn[i]=n+1;
        return ;
    }
    int mid=(l+r)>>1;
    if(R<=mid)
        find(i<<1,l,mid,L,R,lim,val);
    else if(mid<L)
        find(i<<1|1,mid+1,r,L,R,lim,val);
    else{
        find(i<<1,l,mid,L,mid,lim,val);
        find(i<<1|1,mid+1,r,mid+1,R,lim,val);
    }
    mn[i]=min(mn[i<<1],mn[i<<1|1]);
}

void clear(int ql,int qr,int val){
    int l=lower_bound(c+1,c+m+1,node(ql,0,0))-c;
    int r=upper_bound(c+1,c+m+1,node(qr,0,0))-c-1;
    if(l<=r) find(1,1,m,l,r,qr,val);
} 

vector<pair<int,int>> seg[MAXN];
bool cmp(pair<int,int> x,pair<int,int> y){
    return x.first==y.first?x.second>y.second:x.first<y.first;
}
void solve2(){
    for(int i=1;i<=tot;i++)
        seg[b[i].val].push_back(make_pair(b[i].l,b[i].r));
    build(1,1,m);
    for(int i=1;i<=n+2;i++){
        if(seg[i].empty()){
            clear(1,n,i);
            break;
        }
        sort(seg[i].begin(),seg[i].end(),cmp);
        vector<pair<int,int>> temp;
        for(int j=seg[i].size()-1,k=n+1;j>=0;j--){
            if(k>seg[i][j].second) temp.push_back(seg[i][j]);
            k=min(k,seg[i][j].second);
        }
        sort(temp.begin(),temp.end());
        if(temp[0].second>1) clear(1,temp[0].second-1,i);
        for(int j=1;j<temp.size();j++){
            int l=temp[j-1].first+1,r=temp[j].second-1;
            if(l<=r) clear(l,r,i);
        }
        if(temp.back().first<n) clear(temp.back().first+1,n,i);
    }
}
signed main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    solve1();
    for(int i=1;i<=m;i++){
        c[i].l=read(),c[i].r=read();
        c[i].val=i;
    }
    sort(c+1,c+m+1);
    solve2();
    for(int i=1;i<=m;i++) print(ans[i],'\n'); 
    return 0;
}