P10716 题解

· · 题解

一个 A 合法的充要条件为:

建出失配树后,发现合法的 A 在树上组成一条某个点 u 到根的链,且 ui 的祖先。因此我们若知道 u,答案就是 dep_u

考虑倍增求出 u。相当于要 check 这样一个问题:

考虑预处理出每个前缀 S_{1 \sim u} 对应的串在整个串中的不重叠出现位置。相当于每次找到一个位置后面第一个后缀,使得它与整个串的 LCP 长度 \ge u,然后跳到它。跳的步数最多是 \sum\limits_{i = 1}^n \frac{n}{i} = O(n \log n) 所以可以暴力跳。

一个后缀与整个串的 LCP 长度可以想到 Z 函数。用并查集维护链表的 trick,从小到大枚举 u,处理完 u 后把 z_p = u 的所有位置 p 删了,使得处理到 u 时,并查集中 z_p \ge u 的位置为代表元。这样即可 O(\alpha(n)) 求出一个位置后面第一个 p 使得 z_p \ge up

注意特判 k = 1

总时间复杂度 O(q \log n + n \alpha(n) \log n)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const int logn = 20;

int n, m, fail[maxn], dep[maxn], z[maxn], st[logn][maxn], f[logn][maxn], fa[maxn];
char s[maxn];
vector<int> vc[maxn], cv[maxn];

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void solve() {
    scanf("%d%s%d", &n, s + 1, &m);
    for (int i = 2, j = 0; i <= n; ++i) {
        while (j && s[i] != s[j + 1]) {
            j = fail[j];
        }
        j += (s[i] == s[j + 1]);
        fail[i] = j;
    }
    for (int i = 1; i <= n; ++i) {
        f[0][i] = fail[i];
        dep[i] = dep[fail[i]] + 1;
    }
    for (int j = 1; j <= 19; ++j) {
        for (int i = 1; i <= n; ++i) {
            f[j][i] = f[j - 1][f[j - 1][i]];
        }
    }
    z[1] = n;
    for (int i = 2, l = 0, r = 0; i <= n; ++i) {
        if (i <= r) {
            z[i] = min(z[i - l + 1], r - i + 1);
        }
        while (i + z[i] <= n && s[z[i] + 1] == s[i + z[i]]) {
            ++z[i];
        }
        if (i + z[i] - 1 > r) {
            l = i;
            r = i + z[i] - 1;
        }
    }
    for (int i = 1; i <= n; ++i) {
        cv[z[i]].pb(i);
        fa[i] = (z[i] ? i : i + 1);
    }
    fa[n + 1] = n + 1;
    for (int i = 1; i <= n; ++i) {
        vc[i].pb(i);
        int p = i;
        while (p + i <= n) {
            int q = find(p + 1);
            if (q == n + 1) {
                break;
            }
            p = q + i - 1;
            vc[i].pb(p);
        }
        for (int j : cv[i]) {
            fa[j] = j + 1;
        }
    }
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        if (y == 1) {
            puts("1");
            continue;
        }
        int t = x;
        for (int i = 19; ~i; --i) {
            if (f[i][x] && ((int)vc[f[i][x]].size() < y || vc[f[i][x]][y - 1] > t)) {
                x = f[i][x];
            }
        }
        x = f[0][x];
        printf("%d\n", dep[x]);
    }
}

int main() {
    int T = 1;
    // scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}