【题解】Standing Out from the Herd P

· · 题解

看到题解区里没有后缀数组的题解,我就来发一篇。

在阅读之前,你需要对后缀数组和 height 数组有一定的理解。

Part 1

如何求一个串的本质不同的子串个数?我们可以考虑用子串总数减去重复的子串个数。

处理出 height 后,按照字典序从小到大枚举每个后缀,设当前后缀的排名为 i,可以发现后缀 sa[i] 与后缀 sa[i-1]LCPheight[i],这也就是枚举时的前后两个串重复的部分,将其从答案中扣除即可。

所以一个串的本质不同的子串个数为:

\frac{n(n+1)}{2}-\sum_{i=2}^{n}height[i]

Part 2

如何求每个串的只属于自己的本质不同的子串个数?可以限求出每个子串的本质不同子串个数,在扣除公共的部分。

首先把每个串拼接起来,中间插入不同的分隔符,同时对每个部分进行染色。

我们仍然考虑按照字典序枚举,考虑两个相邻的后缀 sa[i-1]sa[i],如果它们属于不同的串,那么它们的 LCP 就应该被扣除。

但是现在我们就要考虑重复的问题,因为我们要从本质不同子串中扣除答案,重复的部分只能减去一次,我们可以通过下面的图形象地理解那些部分要被减掉。

我们用 A,B 来表示不同的串,柱状图的高度表示两个串的 height

i=3 时,出现了公共的部分,此时 A,B 都要减去 height[3]

i=5 时,又出现了公共部分,从图中显然可以看出与上一次减去的部分有重复,且重复的部分完全覆盖了现在的 A,B 串的公共部分,所以这次不能减去。

i=7 时,height[7]>height[5],所以我们这次减去 height[7]-height[5],只有在 height[5] 之上的部分是不重复的。为什么不是减去 height[7]-height[3] 呢?因为后缀 sa[3] 与后缀 sa[7]LCP,也就是它们相同的部分只有 height[5] 及以下的部分,上面都是不同的。

i=14 时,这次不是要减去 height[14]-height[7],而是减去 height[14]-height[11],因为 LCP(13,7)=height[11],只有这部分是与前面减掉的部分有重复。

Part 3

现在有一个大体的思路:按照字典序枚举,考虑两个相邻的后缀 sa[i-1]sa[i],如果它们属于不同的串,则减去它们的 LCP,同时考虑 LCP 与先前减去的部分有多少是重叠的,把这部分再加回去。我们可以维护对每个串维护一个单调栈,维护一个从栈底到栈顶单调递增的序列,表示对于这个串的每次减去的长度(不考虑重复应减去的长度,这个东西其实就是 height)。

显然对于一个串,会重复的部分只可能是 LCP 的部分。考虑 LCP 长度为区间中 height 的最小值的性质,其中不单调的部分不可能是 LCP 的长度,对答案没有贡献,所以我们只要维护单调栈,重复的部分就是栈顶。

同时,若当前要减去的地方与上一次减去的地方这个区间里有更小的 height (上图 i=14),公共的部分就变为这个 height,因此还要维护一个 ST 表。

Part 4

由于要为每个串维护一个单调栈,所以要用链表实现单调栈,或使用 std::stack<T, std::list<T>> 定义一个使用 std::list<T> 为的底层数据结构的栈。若没有使用这个定义,std::stack<T> 等价于 std::stack<T, std::deque<T>>std::deque<T> 本质上是一个块状数据结构,容易 MLE

#include <cstdio>
#include <cstring>
#include <stack>
#include <list>
#include <algorithm>
using namespace std;

constexpr int maxn = 2e5 + 10;

/** 单调栈 */
class SteadyStack {
public:
    void push(int x) {
        while (!st.empty() && st.top() >= x)
            st.pop();

        st.push(x);
    }

    int top() {
        return (st.empty() ? 0 : st.top());
    }
private:
    /** 注意 stack 的定义 */
    stack<int, list<int>> st;
};

class SparseTable {
public:
    void preprocess(int a[], int n) {
        logv[0] = -1;

        for (int i = 1; i <= n; ++i)
            logv[i] = logv[i / 2] + 1;

        for (int i = 1; i <= n; ++i)
            data[i][0] = a[i];

        for (int j = 1; j <= 17; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                data[i][j] = min(data[i][j - 1], data[i + (1 << (j - 1))][j - 1]);
    }

    int query(int l, int r) {
        if (l > r)
            return 0x3f3f3f3f;

        int s = logv[r - l + 1];
        return min(data[l][s], data[r - (1 << s) + 1][s]);
    }
private:
    int data[maxn][18];
    int logv[maxn];
};

int n, len, dat[maxn * 2]; /** dat 为拼接后的串 */
char str[maxn];
int sa[maxn * 4], buf[2][maxn * 4], *rk, height[maxn * 4];
int cnt[maxn * 4], id[maxn * 4], *old, tmp[maxn * 4];
int col[maxn * 2]; /** col[i] 表示在拼接后的串中 i 的属于那个串 */
long long ans[maxn];
SteadyStack st[maxn];
SparseTable mi;
int las[maxn]; /** 上一次减去的位置 */

bool equal(int x, int y, int w) {
    return old[x] == old[y] && old[x + w] == old[y + w];
}

void preprocess(int dat[], int n, int m) {
    for (int i = 1; i <= m; ++i)
        cnt[i] = 0;

    for (int i = 1; i <= n; ++i)
        ++cnt[rk[i] = dat[i]];

    for (int i = 1; i <= m; ++i)
        cnt[i] += cnt[i - 1];

    for (int i = n; i >= 1; --i)
        sa[cnt[rk[i]]--] = i;

    for (int w = 1, p;; w *= 2, m = p) {
        p = 0;

        for (int i = n; i > n - w; --i)
            id[++p] = i;

        for (int i = 1; i <= n; ++i)
            if (sa[i] > w)
                id[++p] = sa[i] - w;

        for (int i = 1; i <= m; ++i)
            cnt[i] = 0;

        for (int i = 1; i <= n; ++i)
            ++cnt[tmp[i] = rk[id[i]]];

        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];

        for (int i = n; i >= 1; --i)
            sa[cnt[tmp[i]]--] = id[i];

        swap(rk, old);
        p = 0;

        for (int i = 1; i <= n; ++i)
            rk[sa[i]] = (equal(sa[i - 1], sa[i], w) ? p : ++p);

        if (p == n) {
            for (int i = 1; i <= n; ++i)
                sa[rk[i]] = i;

            break;
        }
    }

    for (int i = 1, j = 0; i <= n; ++i) {
        if (j)
            --j;

        while (dat[i + j] == dat[sa[rk[i] - 1] + j])
            ++j;

        height[rk[i]] = j;
    }
}

/** 拼接字符串并计算出本质不同子串个数 */
void add(char str[], int id) {
    int start = len, l = strlen(str);

    for (int i = 0; i < l; ++i) {
        dat[++len] = str[i];
        col[len] = id;
    }

    dat[++len] = id + 127;
    preprocess(dat + start, l, 127);

    ans[id] = 1ll * l * (l + 1) / 2;

    for (int i = 1; i <= l; ++i)
        ans[id] -= height[i];
}

/** 扣除重复部分 */
void remove(int id, int l, int p) {
    ans[id] -= max(l - min(st[id].top(), mi.query(las[id] + 1, p - 1)), 0);
    st[id].push(l);
}

int main() {
    rk = buf[0], old = buf[1];

    scanf("%d", &n);

    for (int i = 1; i <= n; ++i) {
        scanf("%s", str);
        add(str, i);
    }

    preprocess(dat, len, 127 + n);
    mi.preprocess(height, len);

    for (int i = 2; i <= len; ++i) {
        if (col[sa[i - 1]] != col[sa[i]]) {
            remove(col[sa[i - 1]], height[i], i);
            remove(col[sa[i]], height[i], i);
            las[col[sa[i - 1]]] = i - 1;
            las[col[sa[i]]] = i;
        }
    }

    for (int i = 1; i <= n; ++i)
        printf("%lld\n", ans[i]);

    return 0;
}