AC 自动机总结笔记

· · 题解

本文旨在帮助一个 pj 组选手理解 AC 自动机。

KMP

引入

给你两个串 s_1,s_2,求出 s_2s_1 中的出现位置。

你略微思考,写出了如下代码:

int n = s1.size(), m = s2.size();
for (int i = 0; i + m - 1 < n; i++) {
    if (s1.substr(i, m) == s2) cout << i << ' ';
}

这样一共要做 n(n-m) 次操作,复杂度为 O(n^2) 级别。

你经过思考,发现没有必要每次都暴力 substr,只需要设定一个匹配指针,如果不匹配就不用再判断了。

但是,字符串不是随机的,比如:

s1: 000000000...0000000000000000000000001
s2: 000000000...0000000001

这会使这个匹配算法退化到 O(n^2) 级别。

优化思路

针对字符串匹配,有两种优化思路:

  1. 优化比较的复杂度;
  2. 优化指针扫描。

第一种优化引出了字符串哈希,而 KMP 属于第二种优化。

KMP 基本思想

对于这样一个字符串 ababc,如果我们已经匹配过 ab 并失配,还有必要把指针退回去匹配吗?这时候,我们已经知道模式串的下一位是 b 不是 a,直接跳过去即可。

跳过去多少呢?我们引入 KMP 中的 next 数组。next 数组就记录了你要跳多少。这个东西在稍后的 AC 自动机中称作失配指针

假设我们已经求出了 next,就可以容易地解决上面的问题:

vector<int> kmp(const string& text, const string& patt) {
  int n = (int) text.size(), m = (int) patt.size();
  int j = -1;
  vector<int> pos;
  for (int i = 0; i < n; i++) {
    for (; j >= 0 && text[i] != patt[j + 1]; j = nxt[j]);  // 跳 next
    if (text[i] == patt[j + 1]) j++;
    if (j == m - 1) pos.push_back(i - m + 1);
  }
  return pos;
}

由于指针只会向前移动,KMP 算法的复杂度是 O(n+m) 的。

求解 next

如果直接暴力求 next,复杂度将退化到 O(m^3)

先给 next 一个严谨的定义:next_i 代表在第 1 到第 i-1 位中的前缀与后缀相同的部分最长是多长。

显然 next 具有如下性质:

你会发现,next 实际上是串自己与自己匹配的过程。做法也是很类似的:

int nxt[N];

inline void get_next(const string& patt) {
  memset(nxt, 0, sizeof(nxt));
  int n = (int) patt.size();
  nxt[0] = -1;
  int j = -1;
  for (int i = 1; i < n; i++) {
    for (; j >= 0 && patt[i] != patt[j + 1]; j = nxt[j]);
    if (patt[i] == patt[j + 1]) j++;
    nxt[i] = j;
  }
}

字典树

字典树(Trie)是通过把一些字符串的相同前缀压缩在一起,达到节省时间 / 空间的目的。

字典树是维护字符串集合的主要工具,下图展示了一棵字典树:

AC 自动机

引入

AC 自动机的本质是在字典树上做 KMP。

对于本题,如果我们对每个 T 都做 KMP,其复杂度为 O(nm),这是无法接受的。

考虑用一棵字典树把 T 压缩起来,字典树上做 KMP。

fail 指针

字典树上的 fail 指针对应了 KMP 的 next 数组。

采用 BFS 层次遍历,对于节点 u,设其父亲为 ffu 的边上字符为 c

若存在 fail_{f,c},对应 next 数组,令 fail_{u,c} \gets fail_{f,c} 即可。

若不存在,就暴力向上跳,一直跳到有为止。反之,指向根。

匹配过程

从起点开始匹配,若失配,就跳到 fail 指针所在节点再次匹配。

拓扑排序优化

暴力构建 fail 的复杂度是 O(m^2) 的,需要优化。

我们把所有 fail 单独拿出来构成一棵树,直接拓扑排序即可。

代码

最后放个这题的代码吧:

#include <bits/stdc++.h>
using namespace std;

#define N 2000005

int n;
string s, t;

int cnt = 1, mp[N];
struct node {
    int son[26], fail, flag, res;
    void clear() {memset(son, 0, sizeof(son)), fail = flag = res = 0;}
} trie[N];

inline void insert(const string& s, int num) {
    int u = 1, n = s.size();
    for (int i = 0; i < n; i++) {
        int v = s[i] - 'a';
        if (!trie[u].son[v]) trie[u].son[v] = ++cnt;
        u = trie[u].son[v];
    }
    if (!trie[u].flag) trie[u].flag = num;
    mp[num] = trie[u].flag;
}

queue<int> q;
int indeg[N];

inline void build() {
    for (int i = 0; i < 26; i++) trie[0].son[i] = 1;
    q.emplace(1);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        int fail = trie[u].fail;
        for (int i = 0; i < 26; i++) {
            int v = trie[u].son[i];
            if (!v) {
                trie[u].son[i] = trie[fail].son[i];
                continue;
            }
            trie[v].fail = trie[fail].son[i];
            indeg[trie[v].fail]++;
            q.emplace(v);
        }
    }
}

void query(const string& s) {
    int u = 1, n = s.size();
    for (int i = 0; i < n; i++) {
        u = trie[u].son[s[i] - 'a'];
        trie[u].res++;
    }
}

int vis[N];
inline void toposort() {
    for (int i = 1; i <= cnt; i++) {
        if (indeg[i] == 0) q.emplace(i);
    }
    while (!q.empty()) {
        int u = q.front(); q.pop();
        vis[trie[u].flag] = trie[u].res;
        int v = trie[u].fail; 
        trie[v].res += trie[u].res;
        if (--indeg[v] == 0) q.emplace(v);
    }
}

void _main() {
    cin >> t;
    cin >> n;
    cnt = 1;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    build();

    query(t);
    toposort();
    for (int i = 1; i <= n; i++) cout << vis[mp[i]] << '\n';
} signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);

    _main();
    return 0;
}