CF1876D Lexichromatography

· · 题解

被诈骗了 & 不能觉得不会下分就过完 C 摆烂啊!

Description

给定长度为 n 的序列 a,要求将序列每个位置染成红色或蓝色,且满足以下条件:

求不同的染色方案数,答案对 998\,244\,353 取模。n,a_i\le 2\times 10^5

Solution

有字典序大小的限制是不好计算答案的,但是这条性质是诈骗。对于任意一种两序列不相等的染色方案,我们把红蓝对调一下,发现这两种方案里一定恰好有一个满足字典序的限制。

也就是说在满足第二条限制的情况下,设总染色方案数为 All,两序列相等的方案数为 cnt,则 ans=\frac{All-cnt}{2}

第二条限制实际就是说我们对于每个 x,把 a_i=x 的位置取出来拼成一个序列,序列里相邻的两个数颜色相反。令不同的 a_icol,对于每种 a_i 确定第一个的颜色就可以确定所有,即 All=2^{col}

考虑如何求红蓝序列相同的方案数。

首先如果某种 a_i 出现了奇数次就直接寄了,否则我们把 a_i 相同的位置从左到右两两分组形成一些线段,限制是每条线段的左右端点不同色。

分类讨论线段 [a,b][c,d] 的位置关系。
如果两条线段没有交集,它们怎么染色互不影响;如果有包含关系,怎么染色都不能让序列相等,令 cnt=0
否则就是它们有交集的情况,设 a<c<b<d。那么 a,c 一定同色,b,d 一定同色。

发现我们的限制的形式都是某两点颜色相同 / 相反,使用形如食物链一题的扩展域并查集维护。那么设连通块个数为 block,则 cnt=2^{block}

实现上,我们对每个线段不用向所有与它有交的线段连边。因为两个与它有交的线段之间也有交,它们一定之前已经连通。所以在其中随便找一个连即可,这样时间复杂度就对了,为 O(n\log n)

#define int long long 
const int N=4e5+5,mod=998244353;
const int inv2=((mod+1)>>1);
int n,a[N];
int lst[N],tot[N],fa[N];
int find(int x) {return fa[x]==x?x:fa[x]=find(fa[x]);}
void merge(int x,int y)
{
    if(find(x)==find(y)) return;
    fa[find(x)]=find(y);
}
struct node{int l,r;} e[N];
vector<int> t[N];
int cnt,all=1,ans=1,b[N];
signed main()
{
    n=read();
    for(int i=1;i<=n;i++) a[i]=b[i]=read();
    sort(b+1,b+n+1);
    for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+n+1,a[i])-b;
    for(int i=1;i<=(n<<1);i++) fa[i]=i;
    for(int i=1;i<=n;i++)
    {
        if(lst[a[i]])
        {
            merge(lst[a[i]],i+n),merge(lst[a[i]]+n,i);
            if(tot[a[i]]&1) e[++cnt]={lst[a[i]],i};
        }
        tot[a[i]]++,lst[a[i]]=i;
    }
    for(int i=1;i<=cnt;i++) t[e[i].l].push_back(e[i].r);
    for(int i=1;i<=n;i++) if(tot[i]) all=all*2%mod;
    for(int i=1;i<=n;i++) if(tot[i]&1) {printf("%lld\n",all*inv2%mod);return 0;}
    set<int> s;
    for(int i=1;i<=n;i++)
    {
        for(auto j:t[i])    
        {
            if(s.upper_bound(j)!=s.end()) {printf("%lld\n",all*inv2%mod);return 0;}
            if(!s.empty())
            {
                int r=*s.begin();
                merge(j,r),merge(j+n,r+n);
            }
            s.insert(j);
        }
        s.erase(i);
    }
    for(int i=1;i<=n;i++) if(find(i)==i) ans=(ans<<1)%mod;
    ans=(all-ans+mod)%mod*((mod+1)/2)%mod;
    printf("%lld\n",ans);
    return 0;
}