题解 CF1609F【Interesting Sections】

· · 题解

😅

题解

考虑分治。设当前要对区间 [l,r] 进行计算,令 m=\lfloor \frac{l+r}{2}\rfloor,我们只需要计算左端点 \in [l,m],右端点 \in [m+1,r] 的合法区间个数,然后递归到 [l,m],[m+1,r] 计算即可。

我们从大到小枚举一个左端点 i\in [l,m],对于一个固定的 i,在 [m+1,r] 之间一定会存在两个分界点 p_1,p_2 使得:

并且,随着 i 逐渐减小,p_1,p_2 都递增。于是我们可以双指针维护 p_1,p_2,并分别考虑三种情况的贡献。第一种和第三种的贡献是显然的,对于第二种,只需要在 p_1,p_2 移动过程中,对于每个 k 维护 \operatorname{popcount}=k 的个数有多少。

这样总复杂度是 \mathcal{O}(n(\log n+\log V)) 的,V\max_i\{a_i\}

代码

因为我用了 memset,所以代码实际上是 \mathcal{O}(n\log n\log V) 的,但 memset 的常数比较小,所以跑得挺快。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define For(Ti,Ta,Tb) for(int Ti=(Ta);Ti<=(Tb);++Ti)
#define Dec(Ti,Ta,Tb) for(int Ti=(Ta);Ti>=(Tb);--Ti)
#define Debug(...) fprintf(stderr,__VA_ARGS__)
#define popc __builtin_popcountll
#define Clear(a) memset(a,0,sizeof(a))
typedef long long ll;
const int N=1e6+5,LogV=65;
const ll Inf=0x3f3f3f3f3f3f3f3f;
int n,cnt1[LogV],cnt2[LogV],cnt3[LogV],cnt4[LogV],eq[N];ll a[N],ans;
void Solve(int l,int r){
    if(l>r) return;
    if(l==r){++ans;return;}
    int mid=(l+r)/2;
    Clear(cnt1),Clear(cnt2),Clear(cnt3),Clear(cnt4);
    eq[mid]=0;
    for(ll i=mid+1,mn=a[mid+1],mx=a[mid+1];i<=r;++i,mn=min(mn,a[i]),mx=max(mx,a[i]))
        eq[i]=eq[i-1]+(popc(mn)==popc(mx));
    ll mn=Inf,mx=-Inf,mn1=Inf,mx1=-Inf,mn2=Inf,mx2=-Inf;
    for(int i=mid,p1=mid,p2=mid;i>=l;--i){
        mn=min(mn,a[i]),mx=max(mx,a[i]);
        while(p1<r&&mn<=min(mn1,a[p1+1])&&mx>=max(mx1,a[p1+1]))
            ++p1,mn1=min(mn1,a[p1]),mx1=max(mx1,a[p1]),++cnt1[popc(mn1)],++cnt2[popc(mx1)];
        while(p2<r&&(mn<=min(mn2,a[p2+1])||mx>=max(mx2,a[p2+1])))
            ++p2,mn2=min(mn2,a[p2]),mx2=max(mx2,a[p2]),++cnt3[popc(mn2)],++cnt4[popc(mx2)];
        ans+=(popc(mn)==popc(mx))*(p1-mid)+eq[r]-eq[p2];
        if(mn<=mn2) ans+=cnt4[popc(mn)]-cnt2[popc(mn)];
        else ans+=cnt3[popc(mx)]-cnt1[popc(mx)];
    }
    Solve(l,mid),Solve(mid+1,r);
}
int main(){
    #ifndef zyz
    ios::sync_with_stdio(false),cin.tie(nullptr);
    #endif
    cin>>n;
    For(i,1,n) cin>>a[i];
    Solve(1,n);
    cout<<ans;
    return 0;
}