题解:AT_abc458_f [ABC458F] Critical Misread
ysw_2029
·
·
题解
前置知识:
给你 K 个串 S_1,S_2,\dots,S_K,问你有多少个长度为 n 的字符串满足不包含任何一个串作为其子串(连续子序列)。
---
看到子串,显然会想到 AC 自动机。
考虑对这 $K$ 个串建立 AC 自动机。记 $S$ 为 AC 自动机上的结点数。发现 $S=\sum\limits_{i=1}^K|S_i|\le100$,但是 $N$ 很大,则我们考虑在 AC 自动机的状态上设 dp。
不妨设 $f_{a,i,b}$ 表示从 AC 自动机的 $a$ 状态通过转移 $2^i$ 步后到达 $b$ 状态的方案数。当 $i=0$ 时,我们可以对 AC 自动机的每一位及它们的儿子之间进行统计。然后,我们可以枚举状态 $a,b,c$,若当前转移到第 $i$ 层,根据乘法原理,状态转移方程为 $f_{a,i+1,c}\leftarrow f_{a,i+1,c}+f_{a,i,b}\times f_{b,i,c}$。
这样我们就可以处理出 AC 自动机上任意两点间转移的方案数了。
然后考虑如何统计答案。我们设 $g_{1_a}$ 表示当前到达 AC 自动机上状态 $a$ 的方案数。显然最初时 $g_{1_0}=1$,其余为 $0$。然后我们枚举 $N$ 在二进制下的每一位,如果 $N$ 的第 $i$ 位是 $1$ 的话,则进行一次转移。我们枚举当前状态 $a$ 和我们转移到的状态 $b$,并设临时数组 $g_{2_a}$ 表示转移第 $i$ 位 AC 自动机各点的方案数,那么状态转移方程显然为 $g_{2_b}\leftarrow g_{2_b}+g_{1_a}\times f_{a,i,b}$。随后再令 $g_1=g_2$ 并把 $g_2$ 清空即可。
这样的话,在 $N$ 的每一位都枚举完后,答案就为 $\sum\limits_{i=1}^Sg_{1_i}$。
记得取模。记得开 long long。还有,AC 自动机一定要标记哪些点不可达,在转移 dp 时一定要完全避开这些点。
:::success[AC code]
```
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=102,mod=998244353;
int n,k,cnt,f[N][32][N],g1[N],g2[N],ans;
string s;
struct ACM//AKM
{
int fail,son[27];
bool flag;
}a[N*N];
inline void add(string s)
{
int now=0;
for(auto i:s)
{
int&x=a[now].son[~-i^96];
if(!x)x=++cnt;
a[x].flag|=a[now].flag,now=x;
}
a[now].flag=1;
}
inline void build()
{
queue<int>q;
for(int i=0;i<26;i=-~i)
{
if(!a[a[0].son[i]].flag)f[0][0][a[0].son[i]]++;
if(a[0].son[i])q.push(a[0].son[i]);
}
while(!q.empty())
{
int x=q.front();
q.pop();
a[x].flag|=a[a[x].fail].flag;
for(int i=0;i<26;i++)
{
if(a[x].son[i])
a[a[x].son[i]].fail=a[a[x].fail].son[i],q.push(a[x].son[i]);
else a[x].son[i]=a[a[x].fail].son[i];
if(!a[x].flag&&!a[a[x].son[i]].flag)f[x][0][a[x].son[i]]++;
}
}
}
signed main()
{
cin>>n>>k;
for(int i=1;i<=k;i++)
{
cin>>s;
add(s);
}
build();
for(int b=0;b<30;b++)
for(int i=0;i<=cnt;i++)if(!a[i].flag)
for(int j=0;j<=cnt;j++)if(!a[j].flag&&f[i][b][j])
for(int k=0;k<=cnt;k++)if(!a[k].flag&&f[j][b][k])
f[i][b+1][k]=(f[i][b+1][k]+f[i][b][j]*f[j][b][k])%mod;
g1[0]=1;
for(int b=30;~b;b++)if(n&(1ll<<b))
{
for(int i=0;i<=cnt;i++)if(g1[i]&&!a[i].flag)
for(int j=0;j<=cnt;j++)if(f[i][b][j]&&!a[j].flag)
g2[j]=(g2[j]+g1[i]*f[i][b][j])%mod;
for(int i=0;i<=cnt;i++)if(!a[i].flag)g1[i]=g2[i],g2[i]=0;
}
for(int i=0;i<=cnt;i++)if(!a[i].flag)ans=(ans+g1[i])%mod;
cout<<ans
return 0;
}
```
:::