P14220 [ICPC 2024 Kunming I] 生成字符串

· · 题解

首先考虑每次给出的都直接是一个串的情形:对所有串分别建出正串和反串的 Trie,那么「以询问串开头」和「以询问串结尾」这两个限制,相当于集合中的串在两棵 Trie 上都要位于询问串对应结点的子树内,可以描述成二维偏序。再加上时间维就是三维偏序,可以用 CDQ 分治在 $O(q \log^2 q)$ 内解决。

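现在的问题是不能把整棵 Trie 建出来。考虑只建出所有串结尾结点构成的虚树。建虚树首先要将所有串按照字典序排序。单次比较的方法是每次用后缀数组求两个区间的 LCP,因此比较两个分别由 $l_1$、$l_2$ 个区间拼成的串的复杂度为 $O(l_1 + l_2)$,直接排序复杂度不对。考虑将所有串按照区间个数从小到大排序,然后逐个插入排序,具体实现就是逐个插入 `set`:这样插入时集合中已有的串的区间个数都不超过当前串,单次比较是 $O(k_i)$ 的,总复杂度为 $O(\sum k_i \log q)$。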

然后我们已经知道了每个点的 dfn 序,还要知道每个点的出栈时间(即子树内最后一个点的 dfn)。这其实就是对每个串 $s$ 求出:在按 dfn 序排序的序列中,$s$ 之后第一个不以 $s$ 为前缀的串。求出相邻两个串的 LCP 后,可以从后往前用单调栈求出。

注意一些实现细节,比如相等的串对应 Trie 上的同一个结点,要放在一起考虑。

最后时间复杂度为 $O(n \log n + q \log^2 q + \sum k_i \log q)$。

:::info[代码]

#include <bits/stdc++.h>
// Short-hand macros used throughout the solution.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

// Capacity of the global arrays. Some arrays use a fraction of it
// (maxn / 6 for the sparse table and the operation list, maxn >> 1 for vc1/vc2).
// NOTE(review): presumably chosen from the problem limits on n and m — confirm.
const int maxn = 600100;
// Number of sparse-table levels in SA::f (2^20 already exceeds maxn / 6).
const int logn = 21;

// n: length of the base string; m: number of operations;
// ans[i]: answer of operation i when it is a query.
int n, m, ans[maxn];
// s: base string, 1-indexed (read into s + 1); t: s reversed, t[i] = s[n - i + 1].
char s[maxn], t[maxn];

// Suffix array over a 1-indexed buffer [1..n] (n is the global length) plus a sparse
// table on the height array, giving O(1) LCP queries between any two suffixes.
// A is built on s and B on the reversed string t (see solve()).
struct SA {
    // rk[i]: rank of suffix i; sa[r]: start of the suffix with rank r;
    // id / old / cnt: radix-sort scratch; h[r]: LCP of suffixes sa[r - 1] and sa[r];
    // f[k][i]: min(h[i .. i + 2^k - 1]). f only holds maxn / 6 positions per level.
    int rk[maxn], id[maxn], old[maxn], cnt[maxn], sa[maxn], h[maxn], f[logn][maxn / 6];

    // Range minimum of h[l..r] via two overlapping power-of-two windows (needs l <= r).
    inline int qmin(int l, int r) {
        int k = __lg(r - l + 1);
        return min(f[k][l], f[k][r - (1 << k) + 1]);
    }

    // Length of the longest common prefix of the suffixes starting at x and y.
    // Equal positions short-circuit to the suffix length; otherwise it is the
    // minimum height strictly after the smaller rank up to the larger rank.
    inline int lcp(int x, int y) {
        if (x == y) {
            return n - x + 1;
        }
        x = rk[x];
        y = rk[y];
        if (x > y) {
            swap(x, y);
        }
        return qmin(x + 1, y);
    }

    // Builds sa, rk, h and the sparse table f for s[1..n]: prefix doubling with
    // radix sort, then Kasai's algorithm for the height array.
    // NOTE(review): for n == 1 the doubling loop is skipped, so rk[1] keeps the raw
    // character code instead of 1; lcp() is unaffected since it can only be asked
    // lcp(1, 1), which short-circuits.
    inline void build(char *s) {
        // Bucket count: covers raw character codes (first keys) and ranks 1..n.
        int m = max(n, 127);
        // Initial counting sort by the first character.
        for (int i = 1; i <= n; ++i) {
            rk[i] = s[i];
            ++cnt[rk[i]];
        }
        for (int i = 1; i <= m; ++i) {
            cnt[i] += cnt[i - 1];
        }
        for (int i = n; i; --i) {
            sa[cnt[rk[i]]--] = i;
        }
        for (int w = 1; w < n; w <<= 1) {
            // id: suffixes ordered by their second key rk[i + w]. Suffixes whose second
            // half runs past the end (i > n - w) have the smallest key and come first.
            int tot = 0;
            for (int i = n - w + 1; i <= n; ++i) {
                id[++tot] = i;
            }
            for (int i = 1; i <= n; ++i) {
                if (sa[i] > w) {
                    id[++tot] = sa[i] - w;
                }
            }
            // Stable counting sort of id by the first key; old keeps this round's ranks.
            for (int i = 1; i <= m; ++i) {
                cnt[i] = 0;
            }
            for (int i = 1; i <= n; ++i) {
                old[i] = rk[i];
                ++cnt[rk[id[i]]];
            }
            for (int i = 1; i <= m; ++i) {
                cnt[i] += cnt[i - 1];
            }
            for (int i = n; i; --i) {
                sa[cnt[rk[id[i]]]--] = id[i];
            }
            // Re-rank: suffixes with equal (first, second) key pairs share a rank.
            // old[] past n is never written, so it stays 0 (A and B have static storage).
            for (int i = 1, p = 0; i <= n; ++i) {
                if (old[sa[i]] == old[sa[i - 1]] && old[sa[i] + w] == old[sa[i - 1] + w]) {
                    rk[sa[i]] = p;
                } else {
                    rk[sa[i]] = ++p;
                }
            }
        }
        // Kasai: h[rk[i]] = LCP(suffix i, its predecessor in sa). k decreases by at
        // most one when moving from suffix i to suffix i + 1.
        h[1] = 0;
        for (int i = 1, k = 0; i <= n; ++i) {
            // The smallest suffix has no predecessor.
            if (rk[i] == 1) {
                continue;
            }
            if (k) {
                --k;
            }
            while (i + k <= n && sa[rk[i] - 1] + k <= n && s[i + k] == s[sa[rk[i] - 1] + k]) {
                ++k;
            }
            h[rk[i]] = k;
        }
        // Sparse table over h[1..n].
        for (int i = 1; i <= n; ++i) {
            f[0][i] = h[i];
        }
        for (int j = 1; (1 << j) <= n; ++j) {
            for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
                f[j][i] = min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
            }
        }
    }
} A, B;

// One operation, indexed by its position 1..m in the input (filled in solve()).
//   o == 1 ('+'): a is the inserted string, as a list of [l, r] intervals of s.
//   o == 2 ('-'): t is the index of the operation whose string is removed
//                 (expected to be a '+' operation).
//   o == 3 (any other token, a query): a is the required prefix and b the required
//                 suffix, both as interval lists of s.
struct que {
    int o, t;
    vector<pii> a, b;
} qq[maxn / 6];

// vc1[k]: (interval list over s, owning operation index) for every '+' string and
// every query prefix. vc2[k]: the same for '+' strings and query suffixes, with the
// intervals remapped onto the reversed string t. t1 / t2 count their entries.
pair< vector<pii>, int > vc1[maxn >> 1], vc2[maxn >> 1];
// st*[op] / ed*[op]: first / last position, in lexicographic order of vc1 (vc2), of
// the block of strings that have op's string as a prefix (its trie-node dfn range).
// p: permutation scratch; stk / top: monotone stack used while computing ed*.
int t1, t2, st1[maxn], ed1[maxn], st2[maxn], ed2[maxn], p[maxn], stk[maxn], top;
// f[i]: LCP length of the (i-1)-th and i-th strings in sorted order.
ll f[maxn];

struct cmp1 {
    inline bool operator () (const int &x, const int &y) const {
        auto &v1 = vc1[x].fst, &v2 = vc1[y].fst;
        int i = 0, j = 0, l1 = 0, l2 = 0;
        while (i < (int)v1.size() && j < (int)v2.size()) {
            int len = A.lcp(v1[i].fst + l1, v2[j].fst + l2);
            if (len < min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1)) {
                return s[v1[i].fst + l1 + len] < s[v2[j].fst + l2 + len];
            }
            len = min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1);
            if (len >= v1[i].scd - v1[i].fst - l1 + 1) {
                l1 = 0;
                ++i;
            } else {
                l1 += len;
            }
            if (len >= v2[j].scd - v2[j].fst - l2 + 1) {
                l2 = 0;
                ++j;
            } else {
                l2 += len;
            }
        }
        if (i < (int)v1.size()) {
            return 0;
        }
        if (j < (int)v2.size()) {
            return 1;
        }
        return x < y;
    }
};

struct cmp2 {
    inline bool operator () (const int &x, const int &y) const {
        auto &v1 = vc2[x].fst, &v2 = vc2[y].fst;
        int i = 0, j = 0, l1 = 0, l2 = 0;
        while (i < (int)v1.size() && j < (int)v2.size()) {
            int len = B.lcp(v1[i].fst + l1, v2[j].fst + l2);
            if (len < min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1)) {
                return t[v1[i].fst + l1 + len] < t[v2[j].fst + l2 + len];
            }
            len = min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1);
            if (len >= v1[i].scd - v1[i].fst - l1 + 1) {
                l1 = 0;
                ++i;
            } else {
                l1 += len;
            }
            if (len >= v2[j].scd - v2[j].fst - l2 + 1) {
                l2 = 0;
                ++j;
            } else {
                l2 += len;
            }
        }
        if (i < (int)v1.size()) {
            return 0;
        }
        if (j < (int)v2.size()) {
            return 1;
        }
        return x < y;
    }
};

inline ll lcp1(int x, int y) {
    // Length of the longest common prefix of the interval-encoded strings vc1[x] and
    // vc1[y] over s, advancing through matching runs with A.lcp.
    const vector<pii> &lhs = vc1[x].fst, &rhs = vc1[y].fst;
    size_t ia = 0, ib = 0;
    int offA = 0, offB = 0;
    ll matched = 0;
    while (ia < lhs.size() && ib < rhs.size()) {
        const int posA = lhs[ia].fst + offA, posB = rhs[ib].fst + offB;
        const int restA = lhs[ia].scd - posA + 1, restB = rhs[ib].scd - posB + 1;
        const int step = min(restA, restB);
        const int common = A.lcp(posA, posB);
        if (common < step) {
            return matched + common;  // mismatch inside the current runs
        }
        matched += step;
        if (step == restA) { ++ia; offA = 0; } else { offA += step; }
        if (step == restB) { ++ib; offB = 0; } else { offB += step; }
    }
    return matched;
}

inline ll lcp2(int x, int y) {
    // Length of the longest common prefix of the interval-encoded strings vc2[x] and
    // vc2[y] over the reversed text t, advancing through matching runs with B.lcp.
    const vector<pii> &lhs = vc2[x].fst, &rhs = vc2[y].fst;
    size_t ia = 0, ib = 0;
    int offA = 0, offB = 0;
    ll matched = 0;
    while (ia < lhs.size() && ib < rhs.size()) {
        const int posA = lhs[ia].fst + offA, posB = rhs[ib].fst + offB;
        const int restA = lhs[ia].scd - posA + 1, restB = rhs[ib].scd - posB + 1;
        const int step = min(restA, restB);
        const int common = B.lcp(posA, posB);
        if (common < step) {
            return matched + common;  // mismatch inside the current runs
        }
        matched += step;
        if (step == restA) { ++ia; offA = 0; } else { offA += step; }
        if (step == restB) { ++ib; offB = 0; } else { offB += step; }
    }
    return matched;
}

inline bool eq1(int x, int y) {
    auto &v1 = vc1[x].fst, &v2 = vc1[y].fst;
    int i = 0, j = 0, l1 = 0, l2 = 0;
    while (i < (int)v1.size() && j < (int)v2.size()) {
        int len = A.lcp(v1[i].fst + l1, v2[j].fst + l2);
        if (len < min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1)) {
            return 0;
        }
        len = min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1);
        if (len >= v1[i].scd - v1[i].fst - l1 + 1) {
            l1 = 0;
            ++i;
        } else {
            l1 += len;
        }
        if (len >= v2[j].scd - v2[j].fst - l2 + 1) {
            l2 = 0;
            ++j;
        } else {
            l2 += len;
        }
    }
    if (i < (int)v1.size() || j < (int)v2.size()) {
        return 0;
    }
    return 1;
}

inline bool eq2(int x, int y) {
    auto &v1 = vc2[x].fst, &v2 = vc2[y].fst;
    int i = 0, j = 0, l1 = 0, l2 = 0;
    while (i < (int)v1.size() && j < (int)v2.size()) {
        int len = B.lcp(v1[i].fst + l1, v2[j].fst + l2);
        if (len < min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1)) {
            return 0;
        }
        len = min(v1[i].scd - v1[i].fst - l1 + 1, v2[j].scd - v2[j].fst - l2 + 1);
        if (len >= v1[i].scd - v1[i].fst - l1 + 1) {
            l1 = 0;
            ++i;
        } else {
            l1 += len;
        }
        if (len >= v2[j].scd - v2[j].fst - l2 + 1) {
            l2 = 0;
            ++j;
        } else {
            l2 += len;
        }
    }
    if (i < (int)v1.size() || j < (int)v2.size()) {
        return 0;
    }
    return 1;
}

// One event of the 3-D dominance counting solved by cdq():
//   x: operation index (time), y: position in prefix order, z: position in suffix order,
//   o: weight (+1 insert / -1 remove, or the inclusion-exclusion sign of a query corner),
//   i: 0 for a modification, otherwise the id of the query it contributes to.
// a holds all events; b is the merge buffer used by cdq().
struct node {
    int x, y, z, o, i;
    node(int _x = 0, int _y = 0, int _z = 0, int _o = 0, int _i = 0) : x(_x), y(_y), z(_z), o(_o), i(_i) {}
} a[maxn], b[maxn];

namespace BIT {
    // Fenwick tree over indices 1..m (m is the global operation count).
    int c[maxn];

    // Adds d at index x. x must be >= 1 (index 0 would never advance).
    inline void update(int x, int d) {
        while (x <= m) {
            c[x] += d;
            x += x & -x;
        }
    }

    // Returns the sum over indices 1..x (0 when x == 0).
    inline int query(int x) {
        int total = 0;
        while (x) {
            total += c[x];
            x -= x & -x;
        }
        return total;
    }

    // Zeroes every cell that update(x, ...) touches, undoing index x's contributions.
    inline void clear(int x) {
        while (x <= m) {
            c[x] = 0;
            x += x & -x;
        }
    }
}

// CDQ divide and conquer on a[l..r], which is expected to be sorted by time x.
// For each query node (i != 0) in the right half, it adds to ans[i] the value
// o * (total weight o' of modification nodes (i == 0) in the left half with y' <= y
// and z' <= z). The z dimension is handled by the Fenwick tree BIT.
// On return a[l..r] is sorted by y, and the BIT is empty again.
void cdq(int l, int r) {
    // Fewer than two nodes: nothing to combine. `>=` rather than `==` makes an empty
    // range (e.g. cdq(1, 0) with no events) a no-op instead of infinite recursion.
    if (l >= r) {
        return;
    }
    int mid = (l + r) >> 1;
    cdq(l, mid);
    cdq(mid + 1, r);
    int i = l, j = mid + 1, k = 0;
    // Merge the two y-sorted halves into b. On equal y the left node goes first so
    // right-hand queries with the same y count it (y' <= y).
    while (i <= mid && j <= r) {
        if (a[i].y <= a[j].y) {
            if (!a[i].i) {
                BIT::update(a[i].z, a[i].o);
            }
            b[++k] = a[i++];
        } else {
            if (a[j].i) {
                ans[a[j].i] += BIT::query(a[j].z) * a[j].o;
            }
            b[++k] = a[j++];
        }
    }
    while (i <= mid) {
        if (!a[i].i) {
            BIT::update(a[i].z, a[i].o);
        }
        b[++k] = a[i++];
    }
    while (j <= r) {
        if (a[j].i) {
            ans[a[j].i] += BIT::query(a[j].z) * a[j].o;
        }
        b[++k] = a[j++];
    }
    // a[l..mid] still holds the left half here: undo its BIT updates.
    for (int q = l; q <= mid; ++q) {
        if (!a[q].i) {
            BIT::clear(a[q].z);
        }
    }
    for (int q = 1; q <= k; ++q) {
        a[l + q - 1] = b[q];
    }
}

void solve() {
    scanf("%d%d%s", &n, &m, s + 1);
    A.build(s);
    for (int i = 1; i <= n; ++i) {
        t[i] = s[n - i + 1];
    }
    B.build(t);
    for (int i = 1, k; i <= m; ++i) {
        char o[9];
        scanf("%s", o);
        if (o[0] == '+') {
            qq[i].o = 1;
            scanf("%d", &k);
            while (k--) {
                int l, r;
                scanf("%d%d", &l, &r);
                qq[i].a.pb(l, r);
            }
        } else if (o[0] == '-') {
            qq[i].o = 2;
            scanf("%d", &qq[i].t);
        } else {
            qq[i].o = 3;
            scanf("%d", &k);
            while (k--) {
                int l, r;
                scanf("%d%d", &l, &r);
                qq[i].a.pb(l, r);
            }
            scanf("%d", &k);
            while (k--) {
                int l, r;
                scanf("%d%d", &l, &r);
                qq[i].b.pb(l, r);
            }
        }
    }
    for (int i = 1; i <= m; ++i) {
        if (qq[i].o == 1) {
            vc1[++t1] = mkp(qq[i].a, i);
        } else if (qq[i].o == 3) {
            vc1[++t1] = mkp(qq[i].a, i);
        }
    }
    for (int i = 1; i <= t1; ++i) {
        p[i] = i;
    }
    sort(p + 1, p + t1 + 1, [&](const int &i, const int &j) {
        return vc1[i].fst.size() < vc1[j].fst.size();
    });
    set<int, cmp1> S1;
    for (int i = 1; i <= t1; ++i) {
        S1.insert(p[i]);
    }
    int tot = 0;
    for (int i : S1) {
        p[++tot] = i;
    }
    for (int i = 1, j = 1; i <= tot; i = (++j)) {
        while (j < tot && eq1(p[j], p[j + 1])) {
            ++j;
        }
        for (int k = i; k <= j; ++k) {
            st1[vc1[p[k]].scd] = i;
        }
    }
    for (int i = 2; i <= tot; ++i) {
        f[i] = lcp1(p[i - 1], p[i]);
    }
    for (int i = tot, j = tot; i; i = (--j)) {
        while (j > 1 && eq1(p[j - 1], p[j])) {
            --j;
        }
        ll s = 0;
        for (pii _ : vc1[p[i]].fst) {
            s += _.scd - _.fst + 1;
        }
        while (top && f[stk[top]] >= s) {
            --top;
        }
        for (int k = j; k <= i; ++k) {
            ed1[vc1[p[k]].scd] = (top ? stk[top] - 1 : tot);
        }
        stk[++top] = j;
    }
    for (int i = 1; i <= m; ++i) {
        if (qq[i].o == 1) {
            vector<pii> vc = qq[i].a;
            reverse(vc.begin(), vc.end());
            for (pii &p : vc) {
                p.fst = n - p.fst + 1;
                p.scd = n - p.scd + 1;
                swap(p.fst, p.scd);
            }
            vc2[++t2] = mkp(vc, i);
        } else if (qq[i].o == 3) {
            vector<pii> vc = qq[i].b;
            reverse(vc.begin(), vc.end());
            for (pii &p : vc) {
                p.fst = n - p.fst + 1;
                p.scd = n - p.scd + 1;
                swap(p.fst, p.scd);
            }
            vc2[++t2] = mkp(vc, i);
        }
    }
    for (int i = 1; i <= t2; ++i) {
        p[i] = i;
    }
    sort(p + 1, p + t2 + 1, [&](const int &i, const int &j) {
        return vc2[i].fst.size() < vc2[j].fst.size();
    });
    set<int, cmp2> S2;
    for (int i = 1; i <= t2; ++i) {
        S2.insert(p[i]);
    }
    tot = 0;
    for (int i : S2) {
        p[++tot] = i;
    }
    for (int i = 1, j = 1; i <= tot; i = (++j)) {
        while (j < tot && eq2(p[j], p[j + 1])) {
            ++j;
        }
        for (int k = i; k <= j; ++k) {
            st2[vc2[p[k]].scd] = i;
        }
    }
    for (int i = 2; i <= tot; ++i) {
        f[i] = lcp2(p[i - 1], p[i]);
    }
    top = 0;
    for (int i = tot, j = tot; i; i = (--j)) {
        while (j > 1 && eq2(p[j - 1], p[j])) {
            --j;
        }
        ll s = 0;
        for (pii _ : vc2[p[i]].fst) {
            s += _.scd - _.fst + 1;
        }
        while (top && f[stk[top]] >= s) {
            --top;
        }
        for (int k = j; k <= i; ++k) {
            ed2[vc2[p[k]].scd] = (top ? stk[top] - 1 : tot);
        }
        stk[++top] = j;
    }
    tot = 0;
    for (int i = 1; i <= m; ++i) {
        if (qq[i].o == 1) {
            a[++tot] = node(i, st1[i], st2[i], 1, 0);
        } else if (qq[i].o == 2) {
            int j = qq[i].t;
            a[++tot] = node(i, st1[j], st2[j], -1, 0);
        } else {
            a[++tot] = node(i, ed1[i], ed2[i], 1, i);
            a[++tot] = node(i, st1[i] - 1, ed2[i], -1, i);
            a[++tot] = node(i, ed1[i], st2[i] - 1, -1, i);
            a[++tot] = node(i, st1[i] - 1, st2[i] - 1, 1, i);
        }
    }
    sort(a + 1, a + tot + 1, [&](const node &a, const node &b) {
        if (a.x != b.x) {
            return a.x < b.x;
        } else if (a.y != b.y) {
            return a.y < b.y;
        } else if (a.z != b.z) {
            return a.z < b.z;
        } else {
            return a.i < b.i;
        }
    });
    cdq(1, tot);
    for (int i = 1; i <= m; ++i) {
        if (qq[i].o == 3) {
            printf("%d\n", ans[i]);
        }
    }
}

int main() {
    // Number of test cases, fixed at 1; the multi-test read is left disabled.
    int T = 1;
    // scanf("%d", &T);
    for (; T > 0; --T) {
        solve();
    }
    return 0;
}

:::