题解 SP1812 【LCS2 - Longest Common Substring II 】
更好阅读体验
题目概述
求
题解
对于最短字符串建立SAM。
考虑将其他字符串放到这个SAM上匹配,并记录mx[i]表示以i为结尾的字符串最长匹配长度。
从根开始向后匹配。如果自动机上可以向后匹配字符,就走一条转一边,转移到下一个结点;否则跳parent树,知道存在这个字符的转移边为止。
如果永远没有,那么从根开始重新匹配。每匹配到一个节点,就更新mx[i]。
匹配完再对于每一个mx[fa[i]]从mx[i]更新。这是因为一个节点的par必然在他的儿子中出现,但却有可能没有匹配到par这个节点。
容易证明这样每个结点都有最长匹配值。
再记录mn[i]表示i结点在每一个字符串中的最长匹配长度的最小值。
那么所有结点的mn[i]就是答案。
注:之所以选取最短字符串建立
SAM是因为每次匹配后都要遍历整个SAM并更新SAM每个节点的mx值,如果不选取最短串有可能复杂度并不是\mathcal O(\sum s_i) 的*代码
#include <cstdio> #include <cstring> #include <iostream> using namespace std; const int N = 1e5 + 5, NODE = 2 * N; int n, len[11], mnId, mx[NODE]; char str[11][N]; struct SAM { int ch[NODE][26], len[NODE], fa[NODE], mx[NODE], las, t, buc[N], od[NODE], mn[NODE]; SAM() { las = t = 1; } void ins(int c) { int p = las, np; fa[las = np = ++t] = 1; len[t] = len[p] + 1; for (; p && !ch[p][c]; p = fa[p]) ch[p][c] = np; if (p) { int q = ch[p][c], nq; if (len[p] + 1 == len[q]) fa[np] = q; else { fa[nq = ++t] = fa[q]; len[t] = len[p] + 1; memcpy(ch[nq], ch[q], sizeof(ch[nq])); for (fa[np] = fa[q] = nq; ch[p][c] == q; p = fa[p]) ch[p][c] = nq; } } } void init(int le) { for (int i = 1; i <= t; ++i) ++buc[len[i]]; for (int i = 1; i <= le; ++i) //这里 n并不是字符串长度... buc[i] += buc[i - 1]; for (int i = 1; i <= t; ++i) od[buc[len[i]]--] = i; for (int i = 1; i <= t; ++i) mn[i] = 0x3f3f3f3f; } void calc(char *s, int n) { int now = 1, l = 0; for (int i = 1; i <= t; ++i) mx[i] = 0; for (int i = 1; i <= n; ++i) { while (now && !ch[now][s[i] - 'a']) now = fa[now], l = len[now]; if (now) { now = ch[now][s[i] - 'a']; mx[now] = max(mx[now], ++l); } else { now = 1; l = 0; } } for (int i = 1; i <= t; ++i) { mx[fa[od[i]]] = min(max(mx[fa[od[i]]], mx[od[i]]), len[fa[od[i]]]); mn[od[i]] = min(mn[od[i]], mx[od[i]]); } } } sam; int main() { mnId = 1; while (~scanf("%s", str[++n] + 1)) { len[n] = strlen(str[n] + 1); if (len[mnId] > len[n]) mnId = n; } --n; for (int i = 1; i <= len[mnId]; ++i) sam.ins(str[mnId][i] - 'a'); sam.init(len[mnId]); for (int i = 1; i <= n; ++i) if (i != mnId) sam.calc(str[i], len[i]); int ans = 0; for (int i = 1; i <= sam.t; ++i) ans = max(ans, sam.mn[i]); printf("%d\n", ans); return 0; }