题解:AT_abc458_f [ABC458F] Critical Misread

· · 题解

前置知识:

给你 K 个串 S_1,S_2,\dots,S_K,问你有多少个长度为 n 的字符串满足不包含任何一个串作为其子串(连续子序列)。

--- 看到子串,显然会想到 AC 自动机。 考虑对这 $K$ 个串建立 AC 自动机。记 $S$ 为 AC 自动机上的结点数。发现 $S=\sum\limits_{i=1}^K|S_i|\le100$,但是 $N$ 很大,则我们考虑在 AC 自动机的状态上设 dp。 不妨设 $f_{a,i,b}$ 表示从 AC 自动机的 $a$ 状态通过转移 $2^i$ 步后到达 $b$ 状态的方案数。当 $i=0$ 时,我们可以对 AC 自动机的每一位及它们的儿子之间进行统计。然后,我们可以枚举状态 $a,b,c$,若当前转移到第 $i$ 层,根据乘法原理,状态转移方程为 $f_{a,i+1,c}\leftarrow f_{a,i+1,c}+f_{a,i,b}\times f_{b,i,c}$。 这样我们就可以处理出 AC 自动机上任意两点间转移的方案数了。 然后考虑如何统计答案。我们设 $g_{1_a}$ 表示当前到达 AC 自动机上状态 $a$ 的方案数。显然最初时 $g_{1_0}=1$,其余为 $0$。然后我们枚举 $N$ 在二进制下的每一位,如果 $N$ 的第 $i$ 位是 $1$ 的话,则进行一次转移。我们枚举当前状态 $a$ 和我们转移到的状态 $b$,并设临时数组 $g_{2_a}$ 表示转移第 $i$ 位 AC 自动机各点的方案数,那么状态转移方程显然为 $g_{2_b}\leftarrow g_{2_b}+g_{1_a}\times f_{a,i,b}$。随后再令 $g_1=g_2$ 并把 $g_2$ 清空即可。 这样的话,在 $N$ 的每一位都枚举完后,答案就为 $\sum\limits_{i=1}^Sg_{1_i}$。 记得取模。记得开 long long。还有,AC 自动机一定要标记哪些点不可达,在转移 dp 时一定要完全避开这些点。 :::success[AC code] ``` #include<bits/stdc++.h> #define int long long using namespace std; const int N=102,mod=998244353; int n,k,cnt,f[N][32][N],g1[N],g2[N],ans; string s; struct ACM//AKM { int fail,son[27]; bool flag; }a[N*N]; inline void add(string s) { int now=0; for(auto i:s) { int&x=a[now].son[~-i^96]; if(!x)x=++cnt; a[x].flag|=a[now].flag,now=x; } a[now].flag=1; } inline void build() { queue<int>q; for(int i=0;i<26;i=-~i) { if(!a[a[0].son[i]].flag)f[0][0][a[0].son[i]]++; if(a[0].son[i])q.push(a[0].son[i]); } while(!q.empty()) { int x=q.front(); q.pop(); a[x].flag|=a[a[x].fail].flag; for(int i=0;i<26;i++) { if(a[x].son[i]) a[a[x].son[i]].fail=a[a[x].fail].son[i],q.push(a[x].son[i]); else a[x].son[i]=a[a[x].fail].son[i]; if(!a[x].flag&&!a[a[x].son[i]].flag)f[x][0][a[x].son[i]]++; } } } signed main() { cin>>n>>k; for(int i=1;i<=k;i++) { cin>>s; add(s); } build(); for(int b=0;b<30;b++) for(int i=0;i<=cnt;i++)if(!a[i].flag) for(int j=0;j<=cnt;j++)if(!a[j].flag&&f[i][b][j]) for(int k=0;k<=cnt;k++)if(!a[k].flag&&f[j][b][k]) f[i][b+1][k]=(f[i][b+1][k]+f[i][b][j]*f[j][b][k])%mod; g1[0]=1; for(int b=30;~b;b++)if(n&(1ll<<b)) { for(int i=0;i<=cnt;i++)if(g1[i]&&!a[i].flag) for(int j=0;j<=cnt;j++)if(f[i][b][j]&&!a[j].flag) g2[j]=(g2[j]+g1[i]*f[i][b][j])%mod; for(int i=0;i<=cnt;i++)if(!a[i].flag)g1[i]=g2[i],g2[i]=0; } for(int i=0;i<=cnt;i++)if(!a[i].flag)ans=(ans+g1[i])%mod; cout<<ans return 0; } ``` :::