SP8093 JZPGYZ - Sevenk Love Oimaster 题解

· · 题解

博客食用更佳

Description

给定 n 个模板串,以及 m 个查询串。

依次查询每一个查询串是多少个模板串的子串。

洛谷传送门

Solution

发现题解区好多用 SAM 的,这里提供一篇AC自动机做法的题解。

我们先对询问串建 AC 自动机,然后计算每一个文本串对每一个询问串的贡献。

那么这个贡献要如何计算呢?

考虑对于一个文本串 S。我们先枚举它的每一位,相当于是在枚举前缀。

从每一位跳 fail 指针时,每跳到一个节点,这个节点所代表的字符串都是 S 的一个后缀,也就是 S 的一个字串。

现在我们要求的就是 S 在 AC 自动机上跳到的每一个点到根的链上的点集并(多读几遍吧……)。

容易想到树上差分来维护,即在 x 点,y 点上 +1,lca(x, y) 上 -1。

但暴力是 n^2 的,考虑优化。我们可以把 fail 树的 dfs 序计算出来,然后就变成了区间加以及区间求和。

所以直接用线段树或树状数组维护一下即可。

我用的树状数组。

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

inline int read(){
    int x = 0; char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x;
}

const int N = 1e5 + 10;
const int M = 4e5 + 5;
int n, m;
int trie[M][27], tot = 1, End[M];
char s[M], tmp[10010][N];
int fail[M];

inline void insert(char s[], int id){
    int len = strlen(s + 1), now = 1;
    for(int i = 1; i <= len; i++){
        int c = s[i] - 'a';
        if(!trie[now][c]) trie[now][c] = ++tot;
        now = trie[now][c];
    }
    End[id] = now;
}

inline void build(){
    queue <int> q;
    for(int i = 0; i < 26; i++)
        trie[0][i] = 1;
    q.push(1);
    fail[1] = 0;
    while(!q.empty()){
        int now = q.front();
        q.pop();
        for(int i = 0; i < 26; i++){
            if(trie[now][i]) fail[trie[now][i]] = trie[fail[now]][i], q.push(trie[now][i]);
            else trie[now][i] = trie[fail[now]][i];
        }
    }
}

struct node{
    int v, nxt;
}edge[M];
int head[M], num;
int siz[M], son[M], dep[M];
int top[M], dfn[M], cnt;
int seq[M];

inline bool cmp(int a, int b){
    return dfn[a] < dfn[b];
}

inline void add_edge(int x, int y){
    edge[++num] = (node){y, head[x]};
    head[x] = num;
}

inline void dfs1(int x){
    dep[x] = dep[fail[x]] + 1, siz[x] = 1;
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        dfs1(y);
        siz[x] += siz[y];
        if(!son[x] || siz[y] > siz[son[x]])
            son[x] = y;
    }
}

inline void dfs2(int x, int topfa){
    top[x] = topfa, dfn[x] = ++cnt;
    if(!son[x]) return;
    dfs2(son[x], topfa);
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == son[x]) continue;
        dfs2(y, y);
    }
}

inline int lca(int x, int y){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fail[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

struct BIT{
    int c[M];
    void add(int x, int y){
        for(; x <= tot; x += x & (-x)) c[x] += y;
    }
    int query(int x){
        int res = 0;
        for(; x > 0; x -= x & (-x)) res += c[x];
        return res;
    }
}t;

inline void update(char s[]){
    int len = strlen(s + 1);
    int now = 1;
    for(int i = 1; i <= len; i++){
        now = trie[now][s[i] - 'a'];
        seq[i] = now;
    }
    sort(seq + 1, seq + 1 + len, cmp);
    for(int i = 1; i <= len; i++) t.add(dfn[seq[i]], 1);
    for(int i = 1; i < len; i++) t.add(dfn[lca(seq[i], seq[i + 1])], -1);
}

int main(){
    n = read(), m = read();
    for(int i = 1; i <= n; i++)
        scanf("%s", tmp[i] + 1);
    for(int i = 1; i <= m; i++){
        scanf("%s", s + 1);
        insert(s, i);
    }
    build();
    for(int i = 2; i <= tot; i++)
        add_edge(fail[i], i);
    dfs1(1), dfs2(1, 1);
    for(int i = 1; i <= n; i++)
        update(tmp[i]);
    for(int i = 1; i <= m; i++){
        int x = End[i];
        printf("%d\n", t.query(dfn[x] + siz[x] - 1) - t.query(dfn[x] - 1));
    }
    return 0;
}

End