求区间串复制 k 遍跨过边界的本质不同子串个数的在线单 log 回答询问算法

· · 算法·理论

sto @寻逍遥2006 orz 提供 idea

sto @Aaronwrq orz 提供代码

背景

逍遥老师某天发现一个字符串复制 k 遍后的本质不同子串个数从 k=2 开始就是等差数列,于是尝试做区间串赋值 k 遍本质不同子串个数,然而我最后只得到了求跨过边界的本质不同子串个数的算法。

约定

定义字符串的下标从 1 开始。

定义 |S| 为字符串 S 的长度。

对于字符串 ST,定义 S+T 为将 T 拼接到 S 后面形成的字符串。

对于字符串 S 和正整数 k,定义 S^k=\underbrace{S+S+...+S}_{k 个}

对于字符串 S 和正整数 a,定义 S_aS 中第 a 个字符。

对于字符串 S 和正整数 l,r,定义 S[l:r]=S_l+S_{l+1}+...+S_{r-1}+S_r

对于字符串 S 和正整数 a,定义 pre_{S,a}=S[1,a]suf_{S,a}=S[|S|-a+1:|S|]

对于字符串 ST,定义 ST 本质相同当且仅当 |S|=|T|\forall i \in[1,|S|],S_i=T_i

区间串复制 k 遍跨过边界的本质不同子串个数。

给定一个仅包含小写字母的字符串 S,每次询问给定 l,r,k,求出在字符串 (S[l:r])^k 中,有多少本质不同的子串 T,满足其至少跨过了一个 S[l:r] 的边界。

形式化的,求出在 (S[l:r])^k 中有多少本质不同的 T=(S[l:r])^k[a:b] 满足 \exist\ i \in [a,b),i \equiv 0\ (\bmod \ r-l+1)

好的,那么现在你大概理解了 "跨过边界" 是什么意思,然而它的出现让这个问题看上去很不优美,如果能把 "跨过边界" 去掉多好。

我也是这么想的,然而我太菜了去不掉,如果有人看完这篇文章后有解决 "区间串复制 k 遍本质不同子串个数" 的 idea 可以和我交流一下 qwq。

border theory

这里简单讲解一下前置知识 \text{border theory} 的一部分,如果掌握了可以跳过。

字符串 S 的一个 \text{border} \ B 需要满足 pre_{S,|B|}=suf_{S,|B|}=B

那么显然有对于 S\text{border} \ B,如果 CB\text{border},那么 CS\text{border}

那么我们在尝试求出 S 的所有 \text{border} 时,可以每次尝试求出 S 的最长 \text{border} \ B,然后令 S \leftarrow B,递归处理。

然而 border 的个数是 O(n) 的,这样做自然是复杂度爆炸了。

尝试找一些性质,发现当 2 \times |B| \ge |S| 时,S 中出现了长度为 |S|-|B| 的循环周期,我们称其为 C=pre_{S,|S|-|B|},同时我们记 L=|S| \bmod |C|

可以尝试在纸上画一下,然后就会发现 k \in \mathbb{Z},pre_{S,L+k \times |C|}S\text{border}

也就是说,通过找到 S 的最长 border B,我们可以依此得到很多个 \text{border},且这些 \text{border} 的长度是等差数列。

然后我们再递归处理 pre_{S,L+|C|},可以证明字符串 S 中除了通过上述方法得到的长度为等差数列的 border 集合之外,不存在 border D 满足 |D| \ge \lfloor \frac{L+|C|}{2} \rfloor,因为如果存在这样的 D,就会存在更小的循环周期,也说明存在比 B 更长的 border。这里可以手玩性感理解一下。

那么我们就得到了一个子问题 S',且对于 S' 的最长 \text{border} \ B,可以证明 2 \times |B| < |S|,所以递归次数是 O(\log |S|) 的。

那么我们可以通过这种方式得到字符串 S 的所有 \text{border},这同时也说明了字符串 S 的所有 \text{border} 可以用 O(\log |S|) 个等差数列表示出来。

可以用 SA 做到 O(\log n) 查询区间串最长 \text{border}

会这么多就够用了。

amiya

根据传统,发明一个算法的人有资格命名这个算法,所以我将这种 "在线求区间串跨过边界的本质不同子串个数" 的算法命名为阿米娅。qwq

提到本质不同子串,第一反应应该是 \overset{SAM}{\text{火萤IV型战略强袭装甲}},然而 "区间串复制 k 遍" 让流萤无计可施(大概吧,说不定有高论)。

那么我们尝试用某种方法刻画跨过边界本质相同子串。

对于一次询问 (l,r,k),令 T=(S[l:r])^2,L=r-l+1

先考虑 k=2 的情况。

记 $dis=l2-l1$。 那么有 $T[L-dis+1:L]=T[L+1:L+dis]$,$T[l_1:L-dis]=T[l_2:L]$,$T[L+1:r_1]=T[L+dis+1:r_2]$。 发现这是 $T$ 的一个长为 $dis$ 的 $\text{border}$。 既然如此,我们直接让本质相同的子串在某个 $\text{border}$ 处容斥掉即可,对于字符串的 $\text{border}$ 形态进行比较复杂的分类讨论即可。 具体的,我们求出 $\text{border}$ 的 $O(\log n)$ 个等差数列。对于每一个 $\text{border}$ 的等差数列和其对应的 $S[l:r]$ 的 $\text{border}$,记当前处理的 border 长度为 $len$,最小循环周期为 $cyc=len-border,rem=len \bmod cyc$,记下一个等差数列对应的前缀的长度为 $slen$、最小循环周期为 $scyc$,可以推出对答案的贡献为: - $2 \times border \ge len$: - $len \equiv 0\ (\bmod cyc)$:暂时不做贡献,在上一个等差数列对应的前缀处做贡献。 - $len \not\equiv 0\ (\bmod cyc)$: - 令 $num=\lfloor\frac{len}{cyc}\rfloor-1-[slen \equiv 0\ (\bmod scyc)]$。 - 贡献为 $- num \times \operatorname{LCP}(S[l:l+len-1],S[l+rem:l+len-1]) \times \operatorname{LCS}(S[l:l+len-1],S[l,l+len-1-rem])
- 即当前处理的等差数列的 $\text{border}$ 个数乘上其可能对应的 $LCP$ 和 $LCS$ 个数,由于形成了循环周期,所以在这个等差数列中不同 $\text{border}$ 对应的 $LCS$ 和 $LCP$ 是相等的,求出其中一个即可。
- 其中长度为 $rem$ 的 $\text{border}$ 会在递归到下一个等差数列对应的前缀内处理,所以令 $num \leftarrow num-1$。
- 如果下一个等差数列对应的前缀长度是最小周期的倍数,那么不在这里计算,此时令 $num \leftarrow num-1$。

好的,那么我们就做完了 k=2 的情况,好耶!

而当 k > 2 时,后面的 S[l:r] 总共会做 +(k-2) \times L \times cyc 的贡献,其中 cycS[l:r] 的整循环周期,且若 S[l:r] 的最小周期长度不能被 L 整除,则 cyc=L

没错!当 k>2 时答案关于 k 构成等差数列!

考虑证明,对于一次询问 l,r,k,令 T=(S[l:r])^k,L=r-l+1。本质不同子串的集合为 \{T_1,T_2,...T_{ans}\},子串 T_iT 中的所有出现位置的左端点集合为 A_i

容易发现 \forall i \in [1,ans],\exist a \in A_i,a \in [1,L]

所以每加入一个 S[l:r],我们只需要考虑左端点在 [1,L] 中的子串即可,显然从 k=2 开始每次增加的本质不同子串个数是相等的。

至此,问题解决,总复杂度 O(|S| \log |S| + q \log |S|)

好耶。

关于 “区间串复制 k 遍的本质不同子串个数” 的做法,首先有区间本质不同子串个数,拼上阿米娅之后就只需要处理同时在区间串和跨过边界处出现的子串了,这部分可以转化为一个二维数点问题,然而点数是 O(n) 的(可能有效点数比较少?还没有想过这方面),所以不太好做。

code

这里搬的是同学 @Aaronwrq 的代码,因为我的代码太丑了。

sto @Aaronwrq orz

#include <bits/stdc++.h>
#define MAXN 200010
using namespace std;

int n, q;
string S, T;
int lg[MAXN];
const long long mod = 1e9 + 7;

int val[20][MAXN];
set<int> ps[20][MAXN];

struct SA{
    string st;
    int oldrk[MAXN << 1], rk[MAXN << 1], id[MAXN], lin[MAXN], sa[MAXN], cnt[MAXN];
    int ht[MAXN], s[20][MAXN];
    void Build(bool flag) {
        int V = 127;
        for (int i = 1; i <= n; ++i) ++cnt[rk[i] = st[i]];
        for (int i = 1; i <= V; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i; --i) sa[cnt[rk[i]]--] = i;
        if (flag) for (int i = 1; i <= n; ++i) val[0][i] = rk[i], ps[0][rk[i]].insert(i);
        memcpy(oldrk + 1, rk + 1, sizeof(int) * n);
        V = 0;
        for (int i = 1; i <= n; ++i) {
            if (oldrk[sa[i]] == oldrk[sa[i - 1]]) rk[sa[i]] = V;
            else rk[sa[i]] = ++V;
        }
        for (int w = 1, k = 1; w < n; w <<= 1, ++k) {
            V = 0;
            for (int i = n; i > n - w; --i) id[++V] = i;
            for (int i = 1; i <= n; ++i) if (sa[i] > w) id[++V] = sa[i] - w;
            memset(cnt, 0, sizeof(int) * (V + 1));
            for (int i = 1; i <= n; ++i) ++cnt[lin[i] = rk[id[i]]];
            for (int i = 1; i <= V; ++i) cnt[i] += cnt[i - 1];
            for (int i = n; i; --i) sa[cnt[lin[i]]--] = id[i];
            memcpy(oldrk + 1, rk + 1, sizeof(int) * n);
            V = 0;
            for (int i = 1; i <= n; ++i) {
                if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) rk[sa[i]] = V;
                else rk[sa[i]] = ++V;
            }
            if (flag) for (int i = 1; i <= n; ++i) val[k][i] = rk[i], ps[k][rk[i]].insert(i);
        }
        for (int i = 1, k = 0; i <= n; ++i) if (rk[i]) {
            if (k) --k;
            while (st[i + k] == st[sa[rk[i] - 1] + k]) ++k;
            ht[rk[i]] = k;
        }
        for (int i = 1; i <= n; ++i) s[0][i] = ht[i];
        for (int i = 1; i <= lg[n]; ++i) for (int j = 1; j <= n; ++j) s[i][j] = min(s[i - 1][j], s[i - 1][min(j + (1 << (i - 1)), n)]);
        return;
    }
    int Query(int l, int r) {
        if (l == r) return n - l + 1;
        l = rk[l], r = rk[r];
        if (l > r) swap(l, r);
        ++l;
        int w = lg[r - l + 1];
        return min(s[w][l], s[w][r - (1 << w) + 1]);
    }
}sa1, sa2;

int check(int l, int r, int k) {
    int wl = val[k][l], wr = val[k][r - (1 << k) + 1];
    int lst, led, ld, rst, red, rd;

    auto it = ps[k][wl].upper_bound(r - (1 << k) + 1);
    int lim = max(l + 1, r - (1 << (k + 1)));
    if (it == ps[k][wl].begin() || *--it < lim) return 0;
    lst = r - (1 << k) + 1 - *it;
    if (it == ps[k][wl].begin() || *--it < lim) ld = 1;
    else ld = r - (1 << k) + 1 - *it - lst;
    it = ps[k][wl].lower_bound(lim);
    led = r - (1 << k) + 1 - *it;

    it = ps[k][wr].lower_bound(l);
    lim = min(r - (1 << k), l + (1 << k));
    if (it == ps[k][wr].end() || *it > lim) return 0;
    rst = *it - l;
    if (++it == ps[k][wr].end() || *it > lim) rd = 1;
    else rd = *it - l - rst;
    it = ps[k][wr].upper_bound(lim);
    red = *--it - l;

    if (ld == rd) {
        if (lst > red || rst > led || (rst - lst) % ld) return 0;
        return min(led, red) + (1 << k); 
    }
    if ((led - lst) / ld > (red - rst) / rd) {
        swap(lst, rst), swap(led, red), swap(ld, rd);
    }
    for (int i = led; i >= lst; i -= ld) if (i >= rst && i <= red && !((i - rst) % rd)) return i + (1 << k);
    return 0;
}

int maxBorder(int l, int r) {
    int d = r - l, k = 0, bd = 0;
    while (d) d >>= 1, ++k;
    while (k >= 0 && !bd) bd = check(l, r, k), --k;
    return bd;
}

int lcp(int l, int r) {return sa1.Query(l, r);}
int lcs(int l, int r) {return sa2.Query(n - l + 1, n - r + 1);}

long long ans;

int Solve(int l, int r) {
    int len = r - l + 1, bd = maxBorder(l, r), cyc = len - bd;
    int slen = 0, scyc = 0;
    if (!bd) return n + 1;
    if ((bd << 1) >= len) {
        int rem = len % cyc;
        slen = cyc + rem, scyc = Solve(l, l + slen - 1);
        if (rem) {
            long long ld = min(lcs(r, r - rem), len - rem);
            long long rd = min(lcp(l, l + rem), len - rem);
            ans -= (len / cyc - 1 - (!(slen % scyc))) * ld * rd;
        }
    }
    else {
        slen = bd, scyc = Solve(l, l + slen - 1);
        if (slen % scyc) {
            long long ld = min(lcs(r, r - bd), len - bd);
            long long rd = min(lcp(l, l + bd), len - bd);
            ans -= ld * rd;
        }
    }
    if (!(slen % scyc)) {
        long long ld = min(lcs(r, r - scyc) + scyc, len - scyc);
        long long rd = min(lcp(l, l + scyc) + scyc, len - scyc);
        ans -= ld * rd - (ld + rd - scyc) * scyc;
    }
    return cyc;
}

int main()
{
    freopen("mas.in", "r", stdin);
    freopen("mas.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n >> S >> q, T = S;
    reverse(T.begin(), T.end());
    S = '#' + S, T = '#' + T;
    sa1.st = S, sa2.st = T;
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    sa1.Build(1), sa2.Build(0);
    while (q--) {
        int l, r, k; cin >> l >> r >> k;
        if (k <= 1) {cout << "0\n"; continue;}
        long long len = r - l + 1;
        ans = len * len;
        long long cyc = Solve(l, r);
        if (len % cyc) cyc = len;
        else ans -= (len - cyc) * (len - cyc);
        ans = (ans + (k - 2) * len % mod * cyc) % mod;
        cout << ans << "\n";
    }
    return 0;
}