恶补 AC 自动机

· · 算法·理论

前言

听说最近洛谷文艺复兴,正好快到 CSP-J/S 复赛的时间了,还是恶补一下 AC 自动机吧。

温馨提示:本文图少,建议准备好纸笔,一边看一边画图模拟。

同步发布至简陋的博客。

前置知识

理论构造

AC(Aho–Corasick)自动机实际上是以 Trie 为基础,加之 KMP 失配指针的思想,可以进行多模式串匹配,其实差不多就是将原本“序列形”的 KMP 搬到了树上。

那么,我们也把 AC 自动机分为关于 Trie 与 KMP 的两部分进行构建。

以下默认字符串下标从 1 开始,(u,c) 表示 Trie 树上节点 u 通过字符 c 指向的节点,Trie 树的根节点为 0

构建字典树

对于一个模式串集,我们可以按照普通的 Trie 的放大将其全部插入到字典树上。

值得注意的是,自动机中强调状态,而对于在 AC 自动机中 Trie 的一个节点 u 而言,u 点代表的状态实际是根节点到 u 组成的字符串,某个模式串的前缀。转移即为 Trie 上的树边,字符集即是 Trie 的字符集。

构建失配指针

首先我们来回忆 KMP 中如何构建,若有一字符串 s,对于当前求解的 i 而言:

同理,我们将 s[1,i] 这段前缀看做 Trie 树上一条从根结点开始的一条链,对于结点 i,其父亲为 ff 通过字符 c 指向 i。仿照 KMP,有如下步骤:

其实自己画一画图,自己模拟一遍,可以发现和 KMP 极为相似。

偷一张 oiwiki 的图。

代码实现

求解失配指针

那么我们代码如何实现呢?可以发现对于任意节点 ifail_i 的深度一定比其浅,联想到 Trie 是树的结构,所以不难想到使用拓扑排序递推求解。

构建 Trie 的我就不放了,放构建 fail 的吧。

void GetFail(){
  queue<int> que;
  for(int i = 0; i < MAXV; i++){
    if(nxt[0][i]){
      fail[nxt[0][i]] = 0;
      que.push(nxt[0][i]);
    }
  }
  //为什么将根节点的儿子放进队列?
  //不妨画一棵 Trie,手动模拟一下,可以发现若将根节点入队,则它们的 fail 指针置为了本身,而非根节点
  for(; !que.empty(); ){
    int u = que.front(); que.pop();
    for(int i = 0; i < MAXV; i++){
      int v = nxt[u][i];
      if(!v){
        nxt[u][i] = nxt[fail[u]][i];
        //为什么 v 为空的时候要置为 nxt[fail[u]][i]?
        //其实这里应该理解为若在 u 后再追加字符 c,跳 fail 指针时会跳到的终点,即构建 fail 时第二种情况,也方便了下面求解 fail 指针
      }else{
        fail[v] = nxt[fail[u]][i];
        que.push(v);
      }
    }
  }
  //请注意,因为我们将一些原本空的节点改变了,所以如此下来此时已经不具有 Trie 的特点,现在这张图并非字典树原本的结构了
}

多模式串匹配

假设我们有 n 个模式串 s_1,s_2,\dots,s_n,和一个文本串 t,我们需要对于每个模式串统计在文本串中的出现次数。

n = 1,显然我们知道可以使用 KMP,那么,n > 1 时自然需要用 AC 自动机了。

读者不妨先联想 Trie 的结构和 KMP 匹配的方法,来大致猜测代码。

对于查询的文本串 t,我们不妨在根据模式串构建的 Trie 上跑一遍,对于经过的每一个状态,其可以通过 fail 跳到的模式串显然都可以和 t 匹配上,所以有如下代码:

for(char ch : s){
  for(int i = nxt[pos][ch - 'a']; i; i = fail[i]){
    if(exist[i]) vis[exist[i]]++;//出现次数增加
    //exist[i] 为结尾为 i 节点的模式串编号 
  }
  pos = nxt[pos][ch - 'a'];
}

同理,你还可以魔改这个查询,基本原理大致一样。

效率优化

AC 自动机固然好,但是仔细观察,可以发现复杂度很高,复杂度为 O((\sum\limits_{i=1}^{n} \vert s_i \vert )\vert t\vert)

不过,注意到每个节点 i 唯一对应 fail_i(根节点除外,我们默认 fail_0=0),所以,如果建一条 fail_i \to i 的边,其实这是一棵树,我们称之为“失配树”。

那么,回忆多模式匹配,可以发现,每次跳 fail,其实就是在失配树上往祖先跳,我们可以将其挂在失配树上,则对于每个模式串 s,它在文本串中出现的次数其实就是 s 在失配树上对应的节点的子树和,当然你也可以树上差分。

for(char ch : s){
  pos = nxt[pos][ch - 'a'];
  sum[pos]++;//打标记
}

总之,还有其他一些常用优化。例如拓扑排序,我们就是用它求的 fail 指针。再比如 DFS 序、深度,既然是树自然想到它们,转换为序列问题或子树问题再用数据结构维护。还比如可以离线,将询问在失配树上标记,再遍历失配树算贡献。

如此,就可以通过 P5357 了,时间复杂度 O(\sum\limits_{i=1}^{n} \vert s_i \vert + \vert t \vert)

#include<bits/stdc++.h>

using namespace std;
using ll = long long;

const int MAXN = 2e5 + 5, MAXLEN = 2e5 + 5, MAXV = 26;

string s, t[MAXN];
int n, nxt[MAXLEN][MAXV], fail[MAXLEN], cnt, ans[MAXN], sum[MAXLEN];
vector<int> exist[MAXLEN];
vector<int> G[MAXLEN];

void Insert(const string &s, int id){
  int pos = 0;
  for(char ch : s){
    if(!nxt[pos][ch - 'a']) nxt[pos][ch - 'a'] = ++cnt;
    pos = nxt[pos][ch - 'a'];
  }
  exist[pos].push_back(id);//注意可能有重复字符串
}

void GetFail(){//构建 fail
  queue<int> que;
  for(int i = 0; i < MAXV; i++){
    if(nxt[0][i]){
      fail[nxt[0][i]] = 0;
      que.push(nxt[0][i]);
    }
  }
  for(; !que.empty(); ){
    int u = que.front(); que.pop();
    for(int i = 0; i < MAXV; i++){
      int v = nxt[u][i];
      if(!v){
        nxt[u][i] = nxt[fail[u]][i];
      }else{
        fail[v] = nxt[fail[u]][i];
        que.push(v);
      }
    }
  }
}

void DFS(int u){
  for(int v : G[u]){
    DFS(v);
    sum[u] += sum[v];//子树和
  }
  for(int x : exist[u]){
    ans[x] = sum[u];
  }
}

int main(){
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  cin >> n;
  for(int i = 1; i <= n; i++){
    cin >> t[i];
    Insert(t[i], i);
  }
  cin >> s;
  GetFail();
  int pos = 0;
  for(char ch : s){
    pos = nxt[pos][ch - 'a'];
    sum[pos]++;//失配树上打标记
  }
  for(int i = 1; i <= cnt; i++){
    G[fail[i]].push_back(i);//建 fail 树
  }
  DFS(0);
  for(int i = 1; i <= n; i++){
    cout << ans[i] << "\n";
  }
  return 0;
}

自动机上 DP

考试中,基本除了纯考字符串用 AC 自动机,很多更是 AC 自动机上 DP。

以 P4052 为例,你需要统计有多少长度为 m 的,由大写字母组成的字符串 s,满足 n 个模式串中至少有一个模式串在 s 中出现过,取模 10^4+7

首先我们转换题意,令 c 为不满足条件的字符串数量,则答案为 26^m-c,如何计算 c 呢?

计数题,显然需要使用 DP,而我们又不能记录完整的字符串。我们不妨对模式串建立 AC 自动机,并设立状态 dp_{i,p} 表示已经决定了前 i 个字符,此时的字符串在 Trie 上对应的节点编号,则转移为 dp_{i,p} \to dp_{i+1,(p,c)},其中 c 为大写字母,且需要满足追加字符 c 后没有模式串出现在现在的字符串里。

那么,我们可以写出如下代码:

for(int i = 0; i < m; i++){
  for(int p = 0; p <= tot; p++){//tot 为字典树节点数量
    for(char ch = 'A'; ch <= 'Z'; ch++){//枚举追加字符
      int next_p = next[p][ch - 'A'];//追加后的节点编号
      bool flag = 0;
      for(int x = next_p; x; x = fail[x]){//跳 fail,看是否有模式串出现
        flag |= exist[x];//exist[x] 为节点 x 是否有模式串以其结尾
      }
      if(!flag) dp[i + 1][next_p] += dp[i][p];//转移
    }
  }
}

但是,复杂度爆炸,我们如何优化?瓶颈在于暴力跳 fail

自然有聪明的读者知道,我们可以用拓扑排序对于每个 Trie 上的节点 i 预先处理。可以在求 fail 时顺便递推处理。

#include<bits/stdc++.h>

using namespace std;
using ll = long long;
using pii = pair<int, int>;

const int MAXN = 65, MAXM = 1e2 + 1, MAXLEN = 6e3 + 5, MAXL = 26, MOD = 1e4 + 7;

int n, m, nxt[MAXLEN][MAXL], fail[MAXLEN], tot;
short dp[MAXM][MAXLEN];
bool exist[MAXLEN];
string s;

void Insert(const string &s){
  int pos = 0;
  for(char ch : s){
    if(!nxt[pos][ch - 'A']) nxt[pos][ch - 'A'] = ++tot;
    pos = nxt[pos][ch - 'A'];
  }
  exist[pos] = 1;
}

void GetFail(){//求 fail
  queue<int> que;
  for(int i = 0; i < MAXL; i++){
    if(nxt[0][i]){
      que.push(nxt[0][i]);
      fail[nxt[0][i]] = 0;
    }
  }
  for(; !que.empty(); ){
    int u = que.front(); que.pop();
    exist[u] |= exist[fail[u]];//递推,由 fail[u] -> u
    for(int i = 0; i < MAXL; i++){
      int v = nxt[u][i];
      if(v){
        que.push(v);
        fail[v] = nxt[fail[u]][i];
      }else{
        nxt[u][i] = nxt[fail[u]][i];
      }
    }
  }
}

int main(){
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  cin >> n >> m;
  for(int i = 1; i <= n; i++){
    cin >> s;
    Insert(s);
  }
  GetFail();
  dp[0][0] = 1;
  for(int i = 0; i < m; i++){
    for(int p = 0; p <= tot; p++){
      for(int ch = 0; ch < MAXL; ch++){
        if(!exist[nxt[p][ch]]) (dp[i + 1][nxt[p][ch]] += dp[i][p]) %= MOD;//合法则转移
      }
    }
  }
  int ans = 0;
  for(int i = 0; i <= tot; i++){
    (ans += dp[m][i]) %= MOD;
  }
  int res = 1;
  for(int i = 1; i <= m; i++){
    res = res * 26 % MOD;
  }
  cout << (res - ans + MOD) % MOD;
  return 0;
}

其实不止 AC 自动机,理应所有自动机都可以 DP,因为 DP 和自动机都有共同点——状态和转移。

练习题

后记

其实自己写完这篇文章对 AC 自动机的理解也加深了很多,解决了原来的一些疑惑。字符串这种东西真只能自己画图模拟理解,不然特别的模糊抽象。

第一次写这种文章,如有漏洞还请指出。

马上 CSPJ/S 复赛了,祝各位 RP++。