题解 P4094 【[HEOI2016/TJOI2016]字符串】
Forwarcl
·
·
题解
SAM做法不一定要用反串构建SAM。给出一种正串建SAM的做法。
如果直接求询问的答案似乎很困难,因为 s[a \dots b] 的所有子串数量级是 O(n^2) 级别的,一个个算 LCP (即使是在SAM上) 肯定也是不行的。
考虑二分答案。易知答案的范围在 [0,\min \{ b-a+1,d-c+1\}] 之间,假设现在二分答案 mid ,即答案是否大于等于 mid 。若答案大于等于 mid ,那么根据题意,s[c \dots c+mid-1] 必定是 s[a \dots b] 的子串。 那么我们可以在SAM上找到表示 s[c \dots c+mid-1] 的点,假设为点 p ,然后查询点 p 包不包含 s[a \dots b] ,即 p 的 endpos 集合在区间 [a+mid-1,b] 中有没有元素。这个判断用线段树合并即可完成。
那么如何求点 p 呢?根据SAM的Parent Tree的性质,若 u 是 v 的祖先,那么 u 所能代表的字符串一定都是 v 所能代表的后缀。那么我们可以找到表示前缀 c[1 \dots c+mid-1] 的点,设其为 u ,然后不断往上跳
```cpp
#include <vector>
#include <stdio.h>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
inline bool islower (char &ch) { return ch >= 'a' && ch <= 'z'; }
inline bool isdigit (char &ch) { return ch >= '0' && ch <= '9'; }
inline int idx (char &ch) { return ch - 'a'; } // idx是每个字符在SAM中对应的编号(如'a'对应0)
inline char gc () {
static char buf[1048576], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1048576, stdin), p1 == p2) ? EOF : *p1 ++ ;
}
int readi () {
int ans = 0; char ch = gc();
while(!isdigit(ch)) ch = gc();
while(isdigit(ch)) ans = ans * 10 + ch - '0', ch = gc();
return ans;
}
int reads (char *S) {
int p = 0; char ch = gc();
while(!islower(ch)) ch = gc();
while(islower(ch)) S[++ p] = ch, ch = gc();
S[p + 1] = 0; return p;
}
//以上全是快读
const int N = 1e5 + 5;
int rt[N << 1]; // SAM中每个节点在线段树上的root
struct SegmentTree { // 线段树合并
int sum[N * 80];
int lch[N * 80], rch[N * 80];
int cnt;
inline void pushup (int u) { sum[u] = sum[lch[u]] + sum[rch[u]]; }
void update (int &u, int x, int l, int r, int v) {
if(!u) u = ++ cnt;
if(l == r) { sum[u] += v; return; }
int mid = l + r >> 1;
if(mid >= x) update(lch[u], x, l, mid, v);
else update(rch[u], x, mid + 1, r, v);
pushup(u);
}
int merge (int u, int v, int l, int r) {
if(!u || !v) return u | v;
int w = ++ cnt;
if(l == r) { sum[w] = sum[u] + sum[v]; return w; }
int mid = l + r >> 1;
lch[w] = merge(lch[u], lch[v], l, mid);
rch[w] = merge(rch[u], rch[v], mid + 1, r);
pushup(w); return w;
}
int query (int u, int ql, int qr, int l, int r) {
if(!u) return 0;
if(l >= ql && r <= qr) return sum[u];
int mid = l + r >> 1, ret = 0;
if(mid >= ql) ret += query(lch[u], ql, qr, l, mid);
if(mid < qr) ret += query(rch[u], ql, qr, mid + 1, r);
return ret;
}
} seg;
int n, q;
int ed[N]; // ed[i]代表前缀s[1...i]在SAM中的对应哪个点
struct SAM {
int ch[N << 1][26];
int fa[N << 1][19] /*倍增*/, len[N << 1];
vector <int> g[N << 1]; // Parent Tree
int cnt, lst;
SAM () { cnt = lst = 1; }
void insert (int c) { // 标准SAM构建
int p = lst, np = ++ cnt; lst = np;
len[np] = len[p] + 1; seg.update(rt[np], len[np], 1, n, 1);
for(; p && !ch[p][c]; p = fa[p][0])
ch[p][c] = np;
if(!p) fa[np][0] = 1;
else {
int q = ch[p][c];
if(len[q] == len[p] + 1) fa[np][0] = q;
else {
int nq = ++ cnt;
memcpy(ch[nq], ch[q], sizeof ch[q]);
fa[nq][0] = fa[q][0];
len[nq] = len[p] + 1;
fa[np][0] = fa[q][0] = nq;
for(; p && ch[p][c] == q; p = fa[p][0])
ch[p][c] = nq;
}
}
}
void buildSAM (char *S) { // 建SAM
for(int i = 1; S[i]; ++ i)
insert(idx(S[i]));
}
void updEndPos (char *S) { // ed的含义在上面(这个EndPos不是SAM中的endpos)
//个人习惯在建完SAM之后再更新ed[i],因为SAM的结构可能随着字符的增多而改变之前的形态(endpos出现歧义时候的分裂)
for(int i = 1, p = 1; S[i]; ++ i) {
p = ch[p][idx(S[i])];
ed[i] = p;
}
}
void buildParentTree () { // 建Parent Tree
for(int i = 2; i <= cnt; ++ i)
g[fa[i][0]].push_back(i);
}
void dfsParentTree (int u) { // dfs,预处理倍增和endpos集合
for(int i = 1; i < 19; ++ i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(int i = 0; i < g[u].size(); ++ i) {
int v = g[u][i];
dfsParentTree(v);
rt[u] = seg.merge(rt[u], rt[v], 1, n);
}
}
int findNode (int c, int mid) { // 找符合条件的点p(代表s[c...c+mid-1]的点)
int p = ed[c + mid - 1];
for(int i = 18; ~i; -- i)
if(fa[p][i] && len[fa[p][i]] >= mid) // 只要len[fa[p][i]]大于等于mid就往上跳
p = fa[p][i];
return p;
}
} sam;
bool check (int a, int b, int c, int mid) { // 查询mid是否合法
int p = sam.findNode(c, mid);
return (bool)(seg.query(rt[p], a + mid - 1, b, 1, n)); // s[c...c+mid-1]是否在s[a...b]出现过
}
void solve () {
int a, b, c, d; a = readi(); b = readi(); c = readi(); d = readi();
int l = 0, r = min(b - a + 1, d - c + 1);
int mid;
while(l < r) {
mid = l + r + 1 >> 1; // 向上取整
if(check(a, b, c, mid)) l = mid;
else r = mid - 1;
}
printf("%d\n", l);
}
char S[N];
int main () {
n = readi(); q = readi();
reads(S);
sam.buildSAM(S);
sam.buildParentTree();
sam.dfsParentTree(1);
sam.updEndPos(S);
//全是预处理
while(q -- ) solve();
return 0;
}
```