题解 P5840【[COCI2015] Divljak】

· · 题解

Description

n 个字符串 S_1, S_2, \cdots, S_n。以及一个字典 T,一开始字典 T 为空。

接下来有 q 个操作,操作包含以下两种:

数据范围:1 \leq n, q \leq 10^5,字符集为小写字母集,字符串总长 \leq 2 \times 10^6
时空限制:4000 \ \text{ms} / 500 \ \text{MiB}

Solution

操作 2 是一个多模匹配问题。

考虑将 S_1, S_2, \cdots, S_n 建出 AC 自动机。
对 AC 自动机上的每一个点 u,求出它的失配指针 fail_u,将 fail_u \to u 连边,即可得到一棵 fail 树。

对于字典 T 中新插入的一个字符串 P,考虑求 PS_1, S_2, \cdots, S_n 中的哪些字符串造成贡献。

考虑文本串 P 匹配的过程:P 在 AC 自动机上一个字符一个字符走的过程,相当于枚举了一个前缀,任意时刻在 AC 自动机(trie 图)上走到的节点 u 代表的字符串,即为该前缀与自动机匹配的最长后缀。那么,我们考虑在 u 节点这个位置向上跳 fail 指针,根据 fail 指针的定义,路径上经过的每一个节点代表的字符串都是 P 的子串。

P 在 AC 自动机上依次走到了节点 u_1, u_2, \cdots, u_k
那么 P 在 AC 自动机上能匹配到的子串位于 u_1, u_2, \cdots, u_kfail 树上到根节点上的链的点集并。
那么现在要做的是将该点集内的所有点的答案加上 1,本质上是一个树链求并。

注意到根节点是固定的,可以考虑将 u_1, u_2, \cdots, u_k 按照在 fail 树中的 dfs 序排序后,做下面的事情:

现在问题转化为了:" 路径加 " & " 单点求值 "。
可以使用树上差分将问题转化为:" 单点加 " & " 子树求和 "。

利用在 dfs 序上维护一个树状数组即可实现上述操作。

时间复杂度即为线性对数(线性函数乘上对数函数)。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue> 

using namespace std;

const int N = 100100, SIZE = 2001000;

int n, m;

char S[N]; 

int cT = 1;
struct AC {
    int trans[26];
    int fail;
} t[SIZE];

int End[N];

void insert(char *S, int id) {
    int p = 1, len = strlen(S + 1);
    for (int i = 1; i <= len; i ++) {
        int v = S[i] - 'a';
        if (!t[p].trans[v]) t[p].trans[v] = ++ cT;
        p = t[p].trans[v];
    }
    End[id] = p;
}

void GetFail() {
    for (int i = 0; i < 26; i ++)
        t[0].trans[i] = 1;

    queue<int> q;
    q.push(1), t[1].fail = 0;

    while (q.size()) {
        int u = q.front(); q.pop();
        for (int i = 0; i < 26; i ++) {
            if (t[u].trans[i])
                t[t[u].trans[i]].fail = t[t[u].fail].trans[i], q.push(t[u].trans[i]);
            else
                t[u].trans[i] = t[t[u].fail].trans[i];
        }
    }
}

int tot, head[SIZE], ver[SIZE], Next[SIZE];

void addedge(int u, int v) {
    ver[++ tot] = v;    Next[tot] = head[u];    head[u] = tot;
}

int d[SIZE];
int size[SIZE];
int son[SIZE];

void dfs1(int u) {
    size[u] = 1;

    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i];
        d[v] = d[u] + 1;
        dfs1(v);
        size[u] += size[v];
        if (size[v] > size[son[u]]) son[u] = v;
    }
}

int ovo, dfn[SIZE];
int top[SIZE];

void dfs2(int u) {
    dfn[u] = ++ ovo;

    if (son[u]) {
        top[son[u]] = top[u];
        dfs2(son[u]);
    }

    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i];
        if (v == son[u]) continue; 
        top[v] = v;
        dfs2(v);
    }
}

int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (d[top[x]] > d[top[y]]) swap(x, y);
        y = t[top[y]].fail;
    }
    if (d[x] > d[y]) swap(x, y);
    return x;
}

int c[SIZE];

void add(int x, int val) {
    for (; x <= cT; x += x & -x) c[x] += val;
}

int ask(int x) {
    int ans = 0;
    for (; x; x -= x & -x) ans += c[x];
    return ans;
}

int seq[SIZE];

bool cmp(int i, int j) {
    return dfn[i] < dfn[j];
}

int main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; i ++) {
        scanf("%s", S + 1);
        insert(S, i);
    }

    scanf("%d", &m);

    GetFail();

    for (int i = 2; i <= cT; i ++)
        addedge(t[i].fail, i); 

    d[1] = 1, dfs1(1);
    top[1] = 1, dfs2(1);

    while (m --) {
        int opt, x;
        scanf("%d", &opt);

        switch (opt) {
            case 1: {
                scanf("%s", S + 1);

                int p = 1, len = strlen(S + 1);
                for (int i = 1; i <= len; i ++) {
                    int v = S[i] - 'a';
                    p = t[p].trans[v];
                    seq[i] = p;
                }

                sort(seq + 1, seq + 1 + len, cmp);

                for (int i = 1; i <= len; i ++) {
                    int p = seq[i];
                    add(dfn[p], 1);
                } 

                for (int i = 1; i < len; i ++) {
                    int p = seq[i], q = seq[i + 1];
                    add(dfn[lca(p, q)], -1);
                } 

                break;
            }

            case 2: {
                scanf("%d", &x);
                int p = End[x];
                printf("%d\n", ask(dfn[p] + size[p] - 1) - ask(dfn[p] - 1));

                break;
            } 
        }
    } 

    return 0;
}