CF1766E

· · 题解

给定一个长为 n 的数组 a 。定义一个区间的权值为:有一个数组的链表,初始为空,对于区间从左到右的每个元素 a_i ,对链表左到右枚举每个数组,如果 a_i 与某个数组末尾元素的二进制AND不为 0 ,则将这个元素放到这个数组的末尾,并停止枚举。如果找不到这样的数组,则新建一个只有一个元素 a_i 的数组,并将这个数组放在链表的尾部,最后链表中数组的个数即是区间的权值。求数组 a 每个区间的权值之和。数据范围 1\le n\le 3\times 10^5,0\le a_i\le 3

由于 a_i 很小,考虑对 a_i 分类讨论。若 a_i=0 ,链表中一定不存在满足条件的数组,每次加入 0 都一定会导致链表长度增加 1 。所以每个 0 对答案的贡献都是包含这个 0 的区间个数,可以单独计算。

不考虑 0 后,由于 31,2,3 的二进制AND都不为 0 ,所以每次加入 3 都一定加入链表从左到右第一个数组中。又因为 12 二进制AND为 0 ,也就是,在链表第一个数组后面,只能存在全部为 1 的数组和全部为 2 的数组,且每种最多一个(后文称为纯 1 数组和纯 2 数组)。

考虑枚举区间的左端点,对其右侧的所有右端点计算贡献和。对于每个左端点,在其向右添加元素时,只要遇到非 0 的数,第一个数组就会出现,并对其右侧所有的右端点产生 1 的贡献。

当第一个数组的尾部为 2 时,如果添加 1 ,纯 1 数组就会出现。如果要让第一个数组尾部为 2 ,只有两种可能:区间从左到右第一个非 0 元素为 2 ,或者在遇到某个 3 后遇到的下一个非 0 元素即为 2 。这之后,如果在未遇到 3 时遇到 1 ,就会对其右侧所有的右端点产生 1 的贡献。纯 2 数组同理。为在 O(1) 时间解决此问题,需要预处理以下 5 个数组。

每个左端点 i 在计算贡献时

总时间复杂度为 O(n)

#include <bits/stdc++.h>
using namespace std;
int a[300000],nval[300000],n1[300000],n2[300000],n31[300000],n32[300000];
int main(int argc, char** argv) {
    ios::sync_with_stdio(false),cin.tie(0);
    long long ans=0,n,i,t;
    cin>>n;
    for(i=0;i<n;i++)
    {
        cin>>a[i];
        if(a[i]==0)ans+=(long long)(i+1)*(n-i);
    }
    for(i=n-1;i>-1;i--)
    {
        if(a[i]>0)nval[i]=i;
        else
        {
            if(i==n-1)nval[i]=-1;
            else nval[i]=nval[i+1];
        }
    }
    for(i=n-1;i>-1;i--)
    {
        if(a[i]==1)n1[i]=i;
        else if(a[i]==0||a[i]==2)
        {
            if(i==n-1)n1[i]=-1;
            else n1[i]=n1[i+1];
        }
        else n1[i]=-1;
    }
    for(i=n-1;i>-1;i--)
    {
        if(a[i]==2)n2[i]=i;
        else if(a[i]==0||a[i]==1)
        {
            if(i==n-1)n2[i]=-1;
            else n2[i]=n2[i+1];
        }
        else n2[i]=-1;
    }
    for(i=n-1;i>-1;i--)
    {
        if(i<n-1&&a[i]==3&&nval[i+1]!=-1&&a[nval[i+1]]==2&&n1[nval[i+1]]!=-1)n31[i]=n1[nval[i+1]];
        else if(i==n-1)n31[i]=-1;
        else n31[i]=n31[i+1];
    }
    for(i=n-1;i>-1;i--)
    {
        if(i<n-1&&a[i]==3&&nval[i+1]!=-1&&a[nval[i+1]]==1&&n2[nval[i+1]]!=-1)n32[i]=n2[nval[i+1]];
        else if(i==n-1)n32[i]=-1;
        else n32[i]=n32[i+1];
    }
    for(i=0;i<n;i++)
    {
        t=nval[i];
        if(t==-1)continue;
        ans+=n-t;
        if(a[t]==1&&n2[t]!=-1)ans+=n-n2[t];
        else if(n32[t]!=-1)ans+=n-n32[t];
        if(a[t]==2&&n1[t]!=-1)ans+=n-n1[t];
        else if(n31[t]!=-1)ans+=n-n31[t];
    }
    cout<<ans<<'\n';
    return 0;
}