P10169 [DTCPC 2024] mex,min,max 题解
一
需要发现,
那么我们可以把合法区间的条件容斥拆开,记
对于
现在重点考虑
二
首先需要知道一个极短极长
如果我们求出了这
但是还有一个问题,这样做会算重,如果我们把上面的
三
分析一下复杂度,预处理 ST 表,各种二分,矩形面积并,都是
代码略长,应该每部分分开写的很清楚。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define PI pair <int,int>
#define fir first
#define sec second
#define ll long long
using namespace std;
int n,k,tot,m;
int a[500005];
int lg[500005];
int mx[21][500005];
int mn[21][500005];
vector <int> f[500005];
vector <PI> g[500005];
struct node{int p,l,r,op;}aa[2000005];
ll ans;
inline void in(int &n){
n=0;
char c=getchar();
while(c<'0' || c>'9') c=getchar();
while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
return ;
}
namespace segmenttree{
int tot;
int rt[500005];
int lc[10000005];
int rc[10000005];
int t[10000005];
inline int ins(int u,int l,int r,int k,int x){
++tot;
lc[tot]=lc[u],rc[tot]=rc[u],t[tot]=t[u];
u=tot;
if(l==r){t[u]=x;return u;}
int mid=(l+r)>>1;
if(k<=mid) lc[u]=ins(lc[u],l,mid,k,x);
else rc[u]=ins(rc[u],mid+1,r,k,x);
t[u]=min(t[lc[u]],t[rc[u]]);
return u;
}
inline int mex(int u,int l,int r,int k){
if(l==r) return l;
int mid=(l+r)>>1;
if(t[lc[u]]<k) return mex(lc[u],l,mid,k);
else return mex(rc[u],mid+1,r,k);
}
}
using namespace segmenttree;
inline int pre(int x,int i){
auto p=lower_bound(f[x].begin(),f[x].end(),i);
if(p!=f[x].begin()) return *(--p);
return 0;
}
inline int nxt(int x,int i){
auto p=upper_bound(f[x].begin(),f[x].end(),i);
if(p!=f[x].end()) return *p;
return 0;
}
inline int askmx(int l,int r){
int len=lg[r-l+1];
return max(mx[len][l],mx[len][r-(1<<len)+1]);
}
inline int askmn(int l,int r){
int len=lg[r-l+1];
return min(mn[len][l],mn[len][r-(1<<len)+1]);
}
inline void init(){
for(int i=1;i<=n;i++) f[a[i]].emplace_back(i),rt[i]=ins(rt[i-1],0,n,a[i],i);
for(int i=1;i<=n;i++) g[a[i]==0].push_back({i,i});
for(int i=1;i<=n;i++){
for(auto tmp:g[i-1]){
int l=tmp.fir,r=tmp.sec;
int p=pre(i-1,l);
if(p) g[mex(rt[r],0,n,p)].push_back({p,r});
int q=nxt(i-1,r);
if(q) g[mex(rt[q],0,n,l)].push_back({l,q});
}
sort(g[i].begin(),g[i].end(),[](PI p,PI q){return p.fir==q.fir?p.sec<q.sec:p.fir>q.fir;});
vector <PI> gg;
int R=1e9;
for(auto tmp:g[i]){
if(R>tmp.sec) gg.emplace_back(tmp);
R=min(R,tmp.sec);
}
g[i]=gg;
}
for(int i=1;i<=n;i++) mx[0][i]=mn[0][i]=a[i];
for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int j=1;j<=lg[n];j++)
for(int i=1;i+(1<<j)-1<=n;i++)
mx[j][i]=max(mx[j-1][i],mx[j-1][i+(1<<(j-1))]),
mn[j][i]=min(mn[j-1][i],mn[j-1][i+(1<<(j-1))]);
return ;
}
inline int get1(int l,int r,int p,int x){
int mid,ans=0;
while(l<=r){
mid=(l+r)>>1;
if(askmx(mid,p)<=x) ans=mid,r=mid-1;
else l=mid+1;
}
return ans;
}
inline int get2(int l,int r,int p,int x){
int mid,ans=0;
while(l<=r){
mid=(l+r)>>1;
if(askmx(p,mid)<=x) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
inline void work(){
for(int i=n;i>=0;i--){
for(auto tmp:g[i]){
int l=tmp.fir,r=tmp.sec;
int p=pre(i,l);
if(p) p++;
else p=1;
int q=nxt(i,r);
if(q) q--;
else q=n;
int L=max(1,get1(p,l,r,i+k)),R=min(n,get2(r,q,l,i+k));
if(L>R) continue;
aa[++m]={L,r,R,1};
aa[++m]={l+1,r,R,-1};
}
}
return ;
}
namespace Segmenttree{
int tg[2000005];
int mnn[2000005];
int ctt[2000005];
inline void pushup(int u){
mnn[u]=min(mnn[u<<1],mnn[u<<1|1]);
ctt[u]=0;
if(mnn[u]==mnn[u<<1]) ctt[u]+=ctt[u<<1];
if(mnn[u]==mnn[u<<1|1]) ctt[u]+=ctt[u<<1|1];
return ;
}
inline void build(int u,int l,int r){
ctt[u]=r-l+1;
if(l==r) return ;
int mid=(l+r)>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
return ;
}
inline void down(int u,int x){mnn[u]+=x;tg[u]+=x;return ;}
inline void pushdown(int u){
if(!tg[u]) return ;
down(u<<1,tg[u]);
down(u<<1|1,tg[u]);
tg[u]=0;
return ;
}
inline void updata(int u,int l,int r,int L,int R,int x){
if(L<=l&&r<=R){down(u,x);return ;}
pushdown(u);
int mid=(l+r)>>1;
if(L<=mid) updata(u<<1,l,mid,L,R,x);
if(R>mid) updata(u<<1|1,mid+1,r,L,R,x);
pushup(u);
}
}
using namespace Segmenttree;
inline void solveA(){
sort(aa+1,aa+1+m,[](node p,node q){return p.p==q.p?p.op<q.op:p.p<q.p;});
build(1,1,n);
int l=1;
for(int i=1;i<=n;i++){
while(aa[l].p<=i&&l<=m) updata(1,1,n,aa[l].l,aa[l].r,aa[l].op),l++;
ans+=n-(mnn[1]?0:ctt[1]);
}
return ;
}
inline void solveB(){
for(int i=1;i<=n;i++){
int l=1,r=i,mid,pos=i+1;
while(l<=r){
mid=(l+r)>>1;
if(askmn(mid,i)+k>=askmx(mid,i)) pos=mid,r=mid-1;
else l=mid+1;
}
ans+=i-pos+1;
}
return ;
}
inline void solveC(){
for(int i=1;i<=n;i++){
int l=1,r=i,mid,pos=i+1;
while(l<=r){
mid=(l+r)>>1;
if(k>=askmx(mid,i)) pos=mid,r=mid-1;
else l=mid+1;
}
ans-=i-pos+1;
}
return ;
}
int main(){
in(n),in(k);
for(int i=1;i<=n;i++) in(a[i]);
init();
work();
solveA();
solveB();
solveC();
printf("%lld\n",ans);
return 0;
}