题解 P6139 【【模板】广义后缀自动机(广义SAM)】

· · 题解

后缀数组做法

前置知识

本题解法

我们知道,height数组的含义是排名相邻的两个后缀的最长公共前缀长度,也可以理解为对于一对公共前缀极长的后缀,我们会重复计算多少相同子串。这样,一个长度为n的字符串S,它本质不同的字符串个数就是

\frac{n(n + 1)}{2} - \sum ^ n _ {i = 1}h_i

那么,如果我们有多个字符串,我们在每两个字符串中间插入一个奇奇怪怪的字符,来防止height数组出现跨字符串的前缀,然后应用刚才那个规律,就可以求本质不同的字符串个数。假如有m个字符串,连接后总长是L,答案就是

\sum _ {i = 1} ^ m (\frac{n_i(n_i + 1)}{2}) - \sum _ {i = 1} ^ {L} h_i

代码

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;
typedef long long LL;
const int MAXN = 2e6;

int n, m = 'z', tot;
LL ans;
int rk[MAXN], sa[MAXN], tx[MAXN], tp[MAXN], ht[MAXN];
char ch[MAXN];
int len;
int a[MAXN];

void RSort() {
    for (int i = 0; i <= m; i++) tx[i] = 0;
    for (int i = 1; i <= n; i++) tx[rk[i]]++;
    for (int i = 1; i <= m; i++) tx[i] += tx[i - 1];
    for (int i = n; i >= 1; i--) sa[tx[rk[tp[i]]]--] = tp[i];
}

void Build() {
    for (int i = 1; i <= n; i++) {
        rk[i] = a[i];
        tp[i] = i;
    }
    RSort();
    for (int p = 1, w = 1; p < n; w <<= 1, m = p) {
        p = 0;
        for (int i = 1; i <= w; i++)
            tp[++p] = n - w + i;
        for (int i = 1; i <= n; i++) 
            if (sa[i] > w)
                tp[++p] = sa[i] - w;
        RSort();
        swap(rk, tp);
        rk[sa[1]] = p = 1;
        for (int i = 2; i <= n; i++)
            rk[sa[i]] = (tp[sa[i]] == tp[sa[i - 1]] && tp[sa[i] + w] == tp[sa[i - 1] + w]) ? p : ++p;
    }
}

void GetH() {
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    int k = 0;
    for (int i = 1; i <= n; i++) {
        if (rk[i] == 1) continue;
        if (k) k--;
        int j = sa[rk[i] - 1];
        while (i + k <= n && j + k <= n && a[i + k] == a[j + k]) k++;
        ht[rk[i]] = k;
    }
}

int main() {
    scanf("%d", &tot);
    while (tot--) {
        scanf("%s", ch + 1);
        len = strlen(ch + 1);
        ans += (LL)len * (len + 1) / 2;
        for (int i = 1; i <= len; i++) a[n + i] = ch[i];
        n += len;
        a[++n] = ++m;
    }
    Build();
    GetH();
    for (int i = 1; i <= n; i++) ans = (ans - ht[i]);
    printf("%lld\n", ans);
    return 0;
}
/*
4
aa
ab
bac
caa
*/