题解 P4094 【[HEOI2016/TJOI2016]字符串】

· · 题解

SAM做法不一定要用反串构建SAM。给出一种正串建SAM的做法。

如果直接求询问的答案似乎很困难,因为 s[a \dots b] 的所有子串数量级是 O(n^2) 级别的,一个个算 LCP (即使是在SAM上) 肯定也是不行的。

考虑二分答案。易知答案的范围在 [0,\min \{ b-a+1,d-c+1\}] 之间,假设现在二分答案 mid ,即答案是否大于等于 mid 。若答案大于等于 mid ,那么根据题意,s[c \dots c+mid-1] 必定是 s[a \dots b] 的子串。 那么我们可以在SAM上找到表示 s[c \dots c+mid-1] 的点,假设为点 p ,然后查询点 p 包不包含 s[a \dots b] ,即 pendpos 集合在区间 [a+mid-1,b] 中有没有元素。这个判断用线段树合并即可完成。

那么如何求点 p 呢?根据SAM的Parent Tree的性质,若 uv 的祖先,那么 u 所能代表的字符串一定都是 v 所能代表的后缀。那么我们可以找到表示前缀 c[1 \dots c+mid-1] 的点,设其为 u ,然后不断往上跳

```cpp #include <vector> #include <stdio.h> #include <cstring> #include <iostream> #include <algorithm> using namespace std; inline bool islower (char &ch) { return ch >= 'a' && ch <= 'z'; } inline bool isdigit (char &ch) { return ch >= '0' && ch <= '9'; } inline int idx (char &ch) { return ch - 'a'; } // idx是每个字符在SAM中对应的编号(如'a'对应0) inline char gc () { static char buf[1048576], *p1 = buf, *p2 = buf; return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1048576, stdin), p1 == p2) ? EOF : *p1 ++ ; } int readi () { int ans = 0; char ch = gc(); while(!isdigit(ch)) ch = gc(); while(isdigit(ch)) ans = ans * 10 + ch - '0', ch = gc(); return ans; } int reads (char *S) { int p = 0; char ch = gc(); while(!islower(ch)) ch = gc(); while(islower(ch)) S[++ p] = ch, ch = gc(); S[p + 1] = 0; return p; } //以上全是快读 const int N = 1e5 + 5; int rt[N << 1]; // SAM中每个节点在线段树上的root struct SegmentTree { // 线段树合并 int sum[N * 80]; int lch[N * 80], rch[N * 80]; int cnt; inline void pushup (int u) { sum[u] = sum[lch[u]] + sum[rch[u]]; } void update (int &u, int x, int l, int r, int v) { if(!u) u = ++ cnt; if(l == r) { sum[u] += v; return; } int mid = l + r >> 1; if(mid >= x) update(lch[u], x, l, mid, v); else update(rch[u], x, mid + 1, r, v); pushup(u); } int merge (int u, int v, int l, int r) { if(!u || !v) return u | v; int w = ++ cnt; if(l == r) { sum[w] = sum[u] + sum[v]; return w; } int mid = l + r >> 1; lch[w] = merge(lch[u], lch[v], l, mid); rch[w] = merge(rch[u], rch[v], mid + 1, r); pushup(w); return w; } int query (int u, int ql, int qr, int l, int r) { if(!u) return 0; if(l >= ql && r <= qr) return sum[u]; int mid = l + r >> 1, ret = 0; if(mid >= ql) ret += query(lch[u], ql, qr, l, mid); if(mid < qr) ret += query(rch[u], ql, qr, mid + 1, r); return ret; } } seg; int n, q; int ed[N]; // ed[i]代表前缀s[1...i]在SAM中的对应哪个点 struct SAM { int ch[N << 1][26]; int fa[N << 1][19] /*倍增*/, len[N << 1]; vector <int> g[N << 1]; // Parent Tree int cnt, lst; SAM () { cnt = lst = 1; } void insert (int c) { // 标准SAM构建 int p = lst, np = ++ cnt; lst = np; len[np] = len[p] + 1; seg.update(rt[np], len[np], 1, n, 1); for(; p && !ch[p][c]; p = fa[p][0]) ch[p][c] = np; if(!p) fa[np][0] = 1; else { int q = ch[p][c]; if(len[q] == len[p] + 1) fa[np][0] = q; else { int nq = ++ cnt; memcpy(ch[nq], ch[q], sizeof ch[q]); fa[nq][0] = fa[q][0]; len[nq] = len[p] + 1; fa[np][0] = fa[q][0] = nq; for(; p && ch[p][c] == q; p = fa[p][0]) ch[p][c] = nq; } } } void buildSAM (char *S) { // 建SAM for(int i = 1; S[i]; ++ i) insert(idx(S[i])); } void updEndPos (char *S) { // ed的含义在上面(这个EndPos不是SAM中的endpos) //个人习惯在建完SAM之后再更新ed[i],因为SAM的结构可能随着字符的增多而改变之前的形态(endpos出现歧义时候的分裂) for(int i = 1, p = 1; S[i]; ++ i) { p = ch[p][idx(S[i])]; ed[i] = p; } } void buildParentTree () { // 建Parent Tree for(int i = 2; i <= cnt; ++ i) g[fa[i][0]].push_back(i); } void dfsParentTree (int u) { // dfs,预处理倍增和endpos集合 for(int i = 1; i < 19; ++ i) fa[u][i] = fa[fa[u][i - 1]][i - 1]; for(int i = 0; i < g[u].size(); ++ i) { int v = g[u][i]; dfsParentTree(v); rt[u] = seg.merge(rt[u], rt[v], 1, n); } } int findNode (int c, int mid) { // 找符合条件的点p(代表s[c...c+mid-1]的点) int p = ed[c + mid - 1]; for(int i = 18; ~i; -- i) if(fa[p][i] && len[fa[p][i]] >= mid) // 只要len[fa[p][i]]大于等于mid就往上跳 p = fa[p][i]; return p; } } sam; bool check (int a, int b, int c, int mid) { // 查询mid是否合法 int p = sam.findNode(c, mid); return (bool)(seg.query(rt[p], a + mid - 1, b, 1, n)); // s[c...c+mid-1]是否在s[a...b]出现过 } void solve () { int a, b, c, d; a = readi(); b = readi(); c = readi(); d = readi(); int l = 0, r = min(b - a + 1, d - c + 1); int mid; while(l < r) { mid = l + r + 1 >> 1; // 向上取整 if(check(a, b, c, mid)) l = mid; else r = mid - 1; } printf("%d\n", l); } char S[N]; int main () { n = readi(); q = readi(); reads(S); sam.buildSAM(S); sam.buildParentTree(); sam.dfsParentTree(1); sam.updEndPos(S); //全是预处理 while(q -- ) solve(); return 0; } ```