CF1574F Occurrences 题解

· · 题解

CF1574F Occurrences

题意

定义 a 的子序列是 [a_l,a_{l+1},\cdots,a_r] 构成的序列,定义 a 序列在 b 序列中出现的次数为 b 序列中与 a 相同的子序列个数。

n 个值域在 [1,k] 的数列 A_1,A_2,\cdots,A_n,要求构造一个长为 m 的序列 a,满足对于所有数列 A_iA_i 在序列 a 中的出现次数大于等于其任意一个非空子序列在 a 中的出现次数。

求不同的 a 的数量,答案 \bmod 998244353

数据范围1\le n,m,k\le 3\times 10^5,\sum c_i\le 3\times 10^5

题解

容易发现,如果 a 中包含了 A_i 中的任意一个元素,一定需要把 A_i 全部填完,而且必须严格按照 A_i 里的顺序。也就是说,如果把相邻元素连有向边的话,我们要把 A_i 构成的链全部填进去。

其次,我们填完 A_i 之后,A_i 中可能还包含 A_j 中的元素,A_j 也需要填完,而且是对位。也就是说,我们需要把 A_i 的链和 A_j 的链合成在一起。例如 3\ 4\ 1\ 2 1\ 2\ 6,他们会变成一条新的链:3\ 4\ 1\ 2\ 6

当然,这条链上必须不能包含环,例如 1\ 2\ 11 的出现永远大于 1\ 2\ 1 的出现次数。这条链也不能有分支,例如 1\ 2\ 3 1\ 2\ 4,不管怎么填,总会把一条链断开,然后就不满足条件了。

所以,把图建出来之后,每个连通块分别考虑,只有该连通块是一条链的时候才能被选,所占的长度是连通块的大小(元素个数)。而不同的连通块大小最多 O(\sqrt k) 个,直接做 DP,枚举所有能够转移的长度。

具体的,记 g_i 表示长为 i 的转移有多少个。

f_i=\sum_{j=1}^{k}f_{i-k}

时间复杂度 O(m\sqrt k)

具体的实现上,使用并查集,维护每个集合的 sizetag 表示这个集合是不是一条链。同时维护每个元素的前驱后继,合并两条链时,如果出现了与原来不同的前驱或者后继,说明这个集合的链产生了分支,不再是一条链;如果连边时两个元素已经在一个集合,那么这个集合会出现环,也不再是一条链。

可以参考代码。

#include <bits/stdc++.h>
using namespace std;
const int N=3e5+5,mod=998244353;
inline void Add(int &a,int b){a+=(a+b>=mod)?b-mod:b;}
int fa[N],sz[N],to[N],a[N],n,m,k,pre[N],tag[N];
int find_fa(int x){return fa[x]==x?x:fa[x]=find_fa(fa[x]);}
inline void merge(int x,int y){
    int fy=find_fa(y),fx=find_fa(x);
    if(fy==fx){tag[fx]=1;return;}
    fa[fy]=fx;
    sz[fx]+=sz[fy];
    tag[fx]|=tag[fy];
}
int buc[N],tot,g[N],f[N];
int main(){
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=k;++i)fa[i]=i,sz[i]=1;
    for(int i=1,c;i<=n;++i){
        scanf("%d",&c);
        for(int j=1;j<=c;++j){
            scanf("%d",&a[j]);
            if(j==1)continue;
            if((pre[a[j]]&&pre[a[j]]!=a[j-1])||(to[a[j-1]]&&to[a[j-1]]!=a[j])){
                merge(a[j-1],a[j]);
                tag[find_fa(a[j])]=1;
            }else if(!to[a[j-1]]){
                to[a[j-1]]=a[j];
                pre[a[j]]=a[j-1];
                merge(a[j-1],a[j]);
            }
        }
    }
    for(int i=1;i<=k;++i){
        if(find_fa(i)!=i||tag[i])continue;
        ++g[sz[i]];
    }
    for(int i=1;i<=k;++i)if(g[i])buc[++tot]=i;
    f[0]=1;
    for(int i=0;i<=m;++i)
        for(int j=1;j<=tot&&i+buc[j]<=m;++j)
            f[i+buc[j]]=(f[i+buc[j]]+1ll*f[i]*g[buc[j]])%mod;
    printf("%d\n",f[m]);
    return 0;
}