【题解】Standing Out from the Herd P
看到题解区里没有后缀数组的题解,我就来发一篇。
在阅读之前,你需要对后缀数组和
Part 1
如何求一个串的本质不同的子串个数?我们可以考虑用子串总数减去重复的子串个数。
处理出
所以一个串的本质不同的子串个数为:
Part 2
如何求每个串的只属于自己的本质不同的子串个数?可以限求出每个子串的本质不同子串个数,在扣除公共的部分。
首先把每个串拼接起来,中间插入不同的分隔符,同时对每个部分进行染色。
我们仍然考虑按照字典序枚举,考虑两个相邻的后缀
但是现在我们就要考虑重复的问题,因为我们要从本质不同子串中扣除答案,重复的部分只能减去一次,我们可以通过下面的图形象地理解那些部分要被减掉。
我们用
当
当
当
当
Part 3
现在有一个大体的思路:按照字典序枚举,考虑两个相邻的后缀
显然对于一个串,会重复的部分只可能是
同时,若当前要减去的地方与上一次减去的地方这个区间里有更小的 ST 表。
Part 4
由于要为每个串维护一个单调栈,所以要用链表实现单调栈,或使用 std::stack<T, std::list<T>> 定义一个使用 std::list<T> 为的底层数据结构的栈。若没有使用这个定义,std::stack<T> 等价于 std::stack<T, std::deque<T>>,std::deque<T> 本质上是一个块状数据结构,容易 MLE。
#include <cstdio>
#include <cstring>
#include <stack>
#include <list>
#include <algorithm>
using namespace std;
constexpr int maxn = 2e5 + 10;
/** 单调栈 */
class SteadyStack {
public:
void push(int x) {
while (!st.empty() && st.top() >= x)
st.pop();
st.push(x);
}
int top() {
return (st.empty() ? 0 : st.top());
}
private:
/** 注意 stack 的定义 */
stack<int, list<int>> st;
};
class SparseTable {
public:
void preprocess(int a[], int n) {
logv[0] = -1;
for (int i = 1; i <= n; ++i)
logv[i] = logv[i / 2] + 1;
for (int i = 1; i <= n; ++i)
data[i][0] = a[i];
for (int j = 1; j <= 17; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
data[i][j] = min(data[i][j - 1], data[i + (1 << (j - 1))][j - 1]);
}
int query(int l, int r) {
if (l > r)
return 0x3f3f3f3f;
int s = logv[r - l + 1];
return min(data[l][s], data[r - (1 << s) + 1][s]);
}
private:
int data[maxn][18];
int logv[maxn];
};
int n, len, dat[maxn * 2]; /** dat 为拼接后的串 */
char str[maxn];
int sa[maxn * 4], buf[2][maxn * 4], *rk, height[maxn * 4];
int cnt[maxn * 4], id[maxn * 4], *old, tmp[maxn * 4];
int col[maxn * 2]; /** col[i] 表示在拼接后的串中 i 的属于那个串 */
long long ans[maxn];
SteadyStack st[maxn];
SparseTable mi;
int las[maxn]; /** 上一次减去的位置 */
bool equal(int x, int y, int w) {
return old[x] == old[y] && old[x + w] == old[y + w];
}
void preprocess(int dat[], int n, int m) {
for (int i = 1; i <= m; ++i)
cnt[i] = 0;
for (int i = 1; i <= n; ++i)
++cnt[rk[i] = dat[i]];
for (int i = 1; i <= m; ++i)
cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i)
sa[cnt[rk[i]]--] = i;
for (int w = 1, p;; w *= 2, m = p) {
p = 0;
for (int i = n; i > n - w; --i)
id[++p] = i;
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id[++p] = sa[i] - w;
for (int i = 1; i <= m; ++i)
cnt[i] = 0;
for (int i = 1; i <= n; ++i)
++cnt[tmp[i] = rk[id[i]]];
for (int i = 1; i <= m; ++i)
cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i)
sa[cnt[tmp[i]]--] = id[i];
swap(rk, old);
p = 0;
for (int i = 1; i <= n; ++i)
rk[sa[i]] = (equal(sa[i - 1], sa[i], w) ? p : ++p);
if (p == n) {
for (int i = 1; i <= n; ++i)
sa[rk[i]] = i;
break;
}
}
for (int i = 1, j = 0; i <= n; ++i) {
if (j)
--j;
while (dat[i + j] == dat[sa[rk[i] - 1] + j])
++j;
height[rk[i]] = j;
}
}
/** 拼接字符串并计算出本质不同子串个数 */
void add(char str[], int id) {
int start = len, l = strlen(str);
for (int i = 0; i < l; ++i) {
dat[++len] = str[i];
col[len] = id;
}
dat[++len] = id + 127;
preprocess(dat + start, l, 127);
ans[id] = 1ll * l * (l + 1) / 2;
for (int i = 1; i <= l; ++i)
ans[id] -= height[i];
}
/** 扣除重复部分 */
void remove(int id, int l, int p) {
ans[id] -= max(l - min(st[id].top(), mi.query(las[id] + 1, p - 1)), 0);
st[id].push(l);
}
int main() {
rk = buf[0], old = buf[1];
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%s", str);
add(str, i);
}
preprocess(dat, len, 127 + n);
mi.preprocess(height, len);
for (int i = 2; i <= len; ++i) {
if (col[sa[i - 1]] != col[sa[i]]) {
remove(col[sa[i - 1]], height[i], i);
remove(col[sa[i]], height[i], i);
las[col[sa[i - 1]]] = i - 1;
las[col[sa[i]]] = i;
}
}
for (int i = 1; i <= n; ++i)
printf("%lld\n", ans[i]);
return 0;
}