题解 P4062 【[Code+#1]Yazid 的新生舞会】

· · 题解

讲一个暴力点的做法。

首先发现如果一个数字是 [l,r] 之间的众数的话,那么这个数字肯定是 [l,mid] 之间的众数或者是 [mid+1,r] 之间的众数。(其中 mid=\left\lfloor\frac{l+r}{2}\right\rfloor

于是就可以分治,对于每一个区间 [l,r] 计算出经过 mid 的满足题意的区间数。

直接对于每个数字暴力计算是 \mathcal{O}(n^2 \operatorname{log}^2 n) 的,肯定过不去,考虑优化。

可以考虑直接挑出可能成为经过 mid 的区间的众数的数字。发现这个数字数量是 \log n 级别的,于是优化到 \mathcal{O}(n\operatorname{log}^3n)

你可能觉得3只log不可能过5e5,但因为实际上压根跑不满,于是它过了,而且还跑得飞快,还跑得比 \sf\color{black}{L}\color{red}{imit} 的线段树做法快...

至于接下来怎么去做,相信很多题解里已经讲过了,我这里再复述一遍。

挑出这些数字以后,我们把问题转化为:给定一个数字 x,求区间 [l,r] 中有多少子区间经过 mid,而且满足区间众数是 x 且区间满足题意。

根据摩尔投票法,我们可以把每一个等于 x 的地方转为 1,否则就是 -1,问题变成了求有多少子区间经过 mid,而且满足区间和大于0。

考虑把区间 [l,r] 前缀和一下,设前缀和后第 i 个点的结果为 sum_i,问题变成了求

\sum_{i=mid}^r \sum_{j=l}^{mid-1} [sum_i>sum_j]

其中 [p] 表示布尔表达式,即 p 是否为真。

这坨东西发现可以轻松用树状数组搞定。然后就没有然后了。

不过发现区间大小为 1 的时候容易重复计算,那么就直接加个区间大小大于 1 的限制即可,即求:

\sum_{i=mid+1}^r \sum_{j=l}^{mid-1} [sum_i>sum_j]

最后加上大小为 1 的区间数量即可。时间复杂度 \mathcal{O}(n\operatorname{log}^3n)

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=500000+5;
int n,tree[N<<2];
int limit;
void add(int x,int c){
    x+=n+1;
    while(x<=limit) tree[x]+=c,x+=(x&-x);  
}
int query(int x){
    int res=0;x+=n+1;
    while(x>0) res+=tree[x],x-=(x&-x);
    return res;
}
long long ans;
int cnt[N],a[N];
set<int>::iterator it;
void calc(int x,int l,int mid,int r){
    int sum=0;
    add(0,1);
    for(int i=l;i<=mid;++i){
        sum+=(a[i]==x?1:-1);
        if(i!=mid) add(sum,1);
        //cout<<sum<<" A"<<endl;
    }
    for(int i=mid+1;i<=r;++i)
        sum+=(a[i]==x?1:-1),
        //cout<<sum<<" B\n",
        ans+=query(sum-1);
    sum=0;
    for(int i=l;i<mid;++i){
        sum+=(a[i]==x?1:-1);
        add(sum,-1);
    }
    add(0,-1);
}
void query(int l,int r){
    if(l==r) return ++ans,void(0);
    int mid=(l+r)>>1;
    query(l,mid);query(mid+1,r);
    set<int> s;
    for(int i=mid+1,mx=0;i<=r;++i){
        ++cnt[a[i]];
        if(cnt[a[i]]>cnt[mx]) mx=a[i],s.insert(mx);
    }
    for(int i=mid+1;i<=r;++i) --cnt[a[i]];
    for(int i=mid,mx=0;i>=l;--i){
        ++cnt[a[i]];
        if(cnt[a[i]]>cnt[mx]) mx=a[i],s.insert(mx);
    }
    //out<<l<<" "<<r<<" seg\n";
    //for(int i=l;i<=r;++i)cout<<a[i]<<" ";cout<<endl;
    for(int i=mid;i>=l;--i) --cnt[a[i]];
    for(it=s.begin();it!=s.end();++it)
        //cout<<l<<" "<<r<<" "<<(*it)<<" "<<ans<<endl,
        calc(*it,l,mid,r);
}
int main(){
    int type;
    scanf("%d%d",&n,&type);limit=n+n+1;
    for(int i=1;i<=n;++i)
        scanf("%d",&a[i]),a[i]=a[i]+1;
    query(1,n);
    cout<<ans;
    return 0;
}