P10169 [DTCPC 2024] mex,min,max 题解

· · 题解

需要发现,\min(mex(l,r),\min(l,r))=0 是恒成立的,因为如果 \min =0,则 mex 一定不为 0,若 \min \ne 0,则 mex=0

那么我们可以把合法区间的条件容斥拆开,记 A 为满足 mex+k \ge \max 的区间个数,B\min+k \ge \max 的个数,Ck \ge \max 的个数,则由上述性质可以得到,ans=A+B-C。接下来只需分别考虑 A,B,C 如何求解即可。

对于 C,固定右端点,合法的左端点是一段连续区间,可以预处理 ST 表,二分解决,对于 B,因为极差也满足单调性(固定右端点,越往左 \min 越小,\max 越大,极差越大),所以合法的左端点也是一段区间,同样可以二分解决。

现在重点考虑 C

首先需要知道一个极短极长 mex 的东西。具体说,如果把所有 \frac{n(n+1)}{2} 个区间 mex 提取出来拍到二维平面上,划分出一些极大的 mex 相同的矩形,这样的矩形只有 O(n) 个。这是个经典结论,不会自行去看 P9970 [THUPC 2024 初赛] 套娃。

如果我们求出了这 O(n) 个矩形,把坐标信息看成 l \in [L_i,l_i],r\in[r_i,R_i] 的话,[l_i,r_i],[L_i,R_i] 就分别是极短极长的 mex 段。对于每一段,上述满足条件的 l,r,其 mex 都是相等的。又因为 \max 具有可重合并的性质,如果我们找到最靠左的 x\in[L_i,l_i],mex+k \ge\max(x,r_i),找到最靠右的 y \in [r_i,R_i],mex+k \ge \max(l_i,y),那么任意 l\in[x,l_i],r\in[r_i,y] 的区间都是满足条件的。这部分可以 O(n\log n) 求出。

但是还有一个问题,这样做会算重,如果我们把上面的 [x,l_i],[r_i,y] 重新看成一个矩形的话,再做一边矩形面积并就不会算重了。因为每个极短极长 mex 段最多只会产生一个矩形,所以复杂度是对的。

分析一下复杂度,预处理 ST 表,各种二分,矩形面积并,都是 O(n\log_2 n) 的,空间 O(n\log_2n),因为要主席树动态求区间 mex

代码略长,应该每部分分开写的很清楚。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define PI pair <int,int>
#define fir first
#define sec second
#define ll long long
using namespace std;

int n,k,tot,m;
int a[500005];
int lg[500005];
int mx[21][500005];
int mn[21][500005];
vector <int> f[500005];
vector <PI> g[500005];
struct node{int p,l,r,op;}aa[2000005];
ll ans;

inline void in(int &n){
    n=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
    return ;
}

namespace segmenttree{

    int tot;
    int rt[500005];
    int lc[10000005];
    int rc[10000005];
    int t[10000005];

    inline int ins(int u,int l,int r,int k,int x){
        ++tot;
        lc[tot]=lc[u],rc[tot]=rc[u],t[tot]=t[u];
        u=tot;
        if(l==r){t[u]=x;return u;}
        int mid=(l+r)>>1;
        if(k<=mid) lc[u]=ins(lc[u],l,mid,k,x);
        else rc[u]=ins(rc[u],mid+1,r,k,x);
        t[u]=min(t[lc[u]],t[rc[u]]);
        return u;
    }

    inline int mex(int u,int l,int r,int k){
        if(l==r) return l;
        int mid=(l+r)>>1;
        if(t[lc[u]]<k) return mex(lc[u],l,mid,k);
        else return mex(rc[u],mid+1,r,k);
    }

}
using namespace segmenttree;

inline int pre(int x,int i){
    auto p=lower_bound(f[x].begin(),f[x].end(),i);
    if(p!=f[x].begin()) return *(--p);
    return 0;
}

inline int nxt(int x,int i){
    auto p=upper_bound(f[x].begin(),f[x].end(),i);
    if(p!=f[x].end()) return *p;
    return 0;
}

inline int askmx(int l,int r){
    int len=lg[r-l+1];
    return max(mx[len][l],mx[len][r-(1<<len)+1]);
}

inline int askmn(int l,int r){
    int len=lg[r-l+1];
    return min(mn[len][l],mn[len][r-(1<<len)+1]);
}

inline void init(){
    for(int i=1;i<=n;i++) f[a[i]].emplace_back(i),rt[i]=ins(rt[i-1],0,n,a[i],i);
    for(int i=1;i<=n;i++) g[a[i]==0].push_back({i,i});
    for(int i=1;i<=n;i++){
        for(auto tmp:g[i-1]){
            int l=tmp.fir,r=tmp.sec;
            int p=pre(i-1,l);
            if(p) g[mex(rt[r],0,n,p)].push_back({p,r});
            int q=nxt(i-1,r);
            if(q) g[mex(rt[q],0,n,l)].push_back({l,q});
        }
        sort(g[i].begin(),g[i].end(),[](PI p,PI q){return p.fir==q.fir?p.sec<q.sec:p.fir>q.fir;});
        vector <PI> gg;
        int R=1e9;
        for(auto tmp:g[i]){
            if(R>tmp.sec) gg.emplace_back(tmp);
            R=min(R,tmp.sec);
        }
        g[i]=gg;
    }
    for(int i=1;i<=n;i++) mx[0][i]=mn[0][i]=a[i];
    for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
    for(int j=1;j<=lg[n];j++)
        for(int i=1;i+(1<<j)-1<=n;i++)
            mx[j][i]=max(mx[j-1][i],mx[j-1][i+(1<<(j-1))]),
            mn[j][i]=min(mn[j-1][i],mn[j-1][i+(1<<(j-1))]);

    return ;
}

inline int get1(int l,int r,int p,int x){
    int mid,ans=0;
    while(l<=r){
        mid=(l+r)>>1;
        if(askmx(mid,p)<=x) ans=mid,r=mid-1;
        else l=mid+1;
    }
    return ans;
}

inline int get2(int l,int r,int p,int x){
    int mid,ans=0;
    while(l<=r){
        mid=(l+r)>>1;
        if(askmx(p,mid)<=x) ans=mid,l=mid+1;
        else r=mid-1;
    }
    return ans;
}

inline void work(){
    for(int i=n;i>=0;i--){
        for(auto tmp:g[i]){
            int l=tmp.fir,r=tmp.sec;
            int p=pre(i,l);
            if(p) p++;
            else p=1;
            int q=nxt(i,r);
            if(q) q--;
            else q=n;
            int L=max(1,get1(p,l,r,i+k)),R=min(n,get2(r,q,l,i+k));
            if(L>R) continue;
            aa[++m]={L,r,R,1};
            aa[++m]={l+1,r,R,-1};
        }
    }

    return ;
}

namespace Segmenttree{

    int tg[2000005];
    int mnn[2000005];
    int ctt[2000005];

    inline void pushup(int u){
        mnn[u]=min(mnn[u<<1],mnn[u<<1|1]);
        ctt[u]=0;
        if(mnn[u]==mnn[u<<1]) ctt[u]+=ctt[u<<1];
        if(mnn[u]==mnn[u<<1|1]) ctt[u]+=ctt[u<<1|1];
        return ;
    }

    inline void build(int u,int l,int r){
        ctt[u]=r-l+1;
        if(l==r) return ;
        int mid=(l+r)>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        return ;
    }

    inline void down(int u,int x){mnn[u]+=x;tg[u]+=x;return ;}

    inline void pushdown(int u){
        if(!tg[u]) return ;
        down(u<<1,tg[u]);
        down(u<<1|1,tg[u]);
        tg[u]=0;
        return ;
    }

    inline void updata(int u,int l,int r,int L,int R,int x){
        if(L<=l&&r<=R){down(u,x);return ;}
        pushdown(u);
        int mid=(l+r)>>1;
        if(L<=mid) updata(u<<1,l,mid,L,R,x);
        if(R>mid) updata(u<<1|1,mid+1,r,L,R,x);
        pushup(u);
    }

}
using namespace Segmenttree;

inline void solveA(){
    sort(aa+1,aa+1+m,[](node p,node q){return p.p==q.p?p.op<q.op:p.p<q.p;});
    build(1,1,n);
    int l=1;
    for(int i=1;i<=n;i++){
        while(aa[l].p<=i&&l<=m) updata(1,1,n,aa[l].l,aa[l].r,aa[l].op),l++;
        ans+=n-(mnn[1]?0:ctt[1]);
    }

    return ;
}

inline void solveB(){
    for(int i=1;i<=n;i++){
        int l=1,r=i,mid,pos=i+1;
        while(l<=r){
            mid=(l+r)>>1;
            if(askmn(mid,i)+k>=askmx(mid,i)) pos=mid,r=mid-1;
            else l=mid+1;
        }
        ans+=i-pos+1;
    }

    return ;
}

inline void solveC(){
    for(int i=1;i<=n;i++){
        int l=1,r=i,mid,pos=i+1;
        while(l<=r){
            mid=(l+r)>>1;
            if(k>=askmx(mid,i)) pos=mid,r=mid-1;
            else l=mid+1;
        }
        ans-=i-pos+1;
    }

    return ;
}

int main(){
    in(n),in(k);
    for(int i=1;i<=n;i++) in(a[i]);
    init();
    work();
    solveA();
    solveB();
    solveC();
    printf("%lld\n",ans);

    return 0;
}