题解 「TAOI-2」Ciallo~(∠・ω< )⌒★

· · 题解

昨天晚上看到室友在玩一个叫千恋万花的游戏,中间忘了,现在我才认清我室友自私自利的嘴脸,根本不值得我给他分享好东西。

一句话题意:求字符串 s 删去每个区间后字符串 t 出现的次数之和。

下文字符串下标从零开始编号。

删去一个区间等价于取一个前缀 pre 和一个后缀 suf,满足 |pre|+|suf|<|s|tpresuf 内部的情况是 trivial 的,我们可以一遍字符串匹配处理掉,考虑 t 横跨两边的情况。

考虑枚举 t 的哪一部分在 pre 中,设这一部分为 t_{0..i},则我们求出所有 s 中匹配 t_{0..i} 的左端点,设为 A_i,求出所有匹配 t_{i+1..|t|-1} 的右端点,设为 B_i,则对答案的贡献为 \sum\limits_{x \in A_i,y \in B_i}[y-x \ge |t|]

乍一看 \sum |A_i|\sum |B_i| 都是 O(n^2) 级别的,似乎没有什么高效算法,但是我们注意到 \forall i < j, A_j \subseteq A_i, B_i \subseteq B_j,于是我们考虑只记录 A_iB_ii 增加时发生的变化,可以发现总变化量是 O(n) 的,这样只要维护变化就可以维护出每个时候的 A_iB_i,然后每次 A_iB_i 变化时,\sum\limits_{x \in A_i,y \in B_i}[y-x \ge |t|] 的改变可以用一个树状数组统计。

考虑怎么维护 A_iB_i 的变化,我们对于每个 s 中的位置可以二分加哈希求出以这个位置为左端点能匹配的 t 的最长长度,这样就可以维护出 A_i 的变化,类似地,可以维护出 B_i 的变化。

总复杂度 O(n \log n)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int bas = 256, mod = 1e9 + 7;
int n, m;
string s, t;
int hasS[400005], hasT[400005];
int basPw[400005];
int GetHashS(int l, int r) { return (hasS[r] - 1ll * hasS[l - 1] * basPw[r - l + 1] % mod + mod) % mod; }
int GetHashT(int l, int r) { return (hasT[r] - 1ll * hasT[l - 1] * basPw[r - l + 1] % mod + mod) % mod; }
vector<int> lcg[400005], rcg[400005];
struct BIT {
    ll f[400005];
    void Modify(int i, int x) {
        for (; i <= n; i += i & -i) f[i] += x;
    }
    ll Query(int l, int r) {
        if (r < l) return 0;
        ll res = 0; l--;
        for (; r; r &= r - 1) res += f[r];
        for (; l; l &= l - 1) res -= f[l];
        return res;
    }
} tf, tg;
ll sum, ans;
void ModifyF(int i, int x) {
    tf.Modify(i, x);
    sum += tg.Query(1, max(i - m, 0)) * x;
}
void ModifyG(int i, int x) {
    tg.Modify(i, x);
    sum += tf.Query(min(i + m, n + 1), n) * x;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr); cout.tie(nullptr);
    cin >> s >> t;
    n = s.length(); m = t.length(); s = '$' + s; t = '$' + t;
    for (int i = 1; i <= n; i++) hasS[i] = (1ll * hasS[i - 1] * bas + s[i]) % mod;
    for (int i = 1; i <= m; i++) hasT[i] = (1ll * hasT[i - 1] * bas + t[i]) % mod;
    basPw[0] = 1;
    for (int i = 1; i <= max(n, m); i++) basPw[i] = 1ll * basPw[i - 1] * bas % mod;
    for (int i = 1; i <= n; i++) {
        int l = 1, r = min(m, n - i + 1), res = 0;
        while (l <= r) {
            int mid = l + r >> 1;
            if (GetHashS(i, i + mid - 1) == GetHashT(1, mid)) {
                res = mid;
                l = mid + 1;
            }
            else {
                r = mid - 1;
            }
        }
        lcg[res].emplace_back(i);
    }
    for (int i = 1; i <= n; i++) {
        int l = 1, r = min(m, i), res = 0;
        while (l <= r) {
            int mid = l + r >> 1;
            if (GetHashS(i - mid + 1, i) == GetHashT(m - mid + 1, m)) {
                res = mid;
                l = mid + 1;
            }
            else {
                r = mid - 1;
            }
        }
        rcg[m - res + 1].emplace_back(i);
    }
    for (int i = 0; i <= m; i++) for (int j : lcg[i]) ModifyG(j, 1);
    for (int i : rcg[1]) ModifyF(i, 1);
    for (int i = 1; i < m; i++) {
        for (int j : lcg[i - 1]) ModifyG(j, -1);
        for (int j : rcg[i + 1]) ModifyF(j, 1);
        ans += sum;
    }
    for (int i = 1; i <= n - m + 1; i++) {
        if (GetHashS(i, i + m - 1) == GetHashT(1, m)) {
            int x = i, y = i + m - 1;
            ans += 1ll * (n - y + 1) * (n - y) / 2;
            ans += 1ll * (x - 1) * x / 2;
        }
    }
    printf("%lld\n", ans);
    return 0;
}