CF1780G Delicious Dessert 题解

· · 题解

题意是求 S 中子串出现次数恰好是其长度倍数时的这些子串的总数量

虽然是 SAM 模版题,但是用 SA 也可以做,做法和 P3804 类似,核心是求出长度为 i 的所有子串,都各自出现了多少次。

先求出 \operatorname{height}_i,然后按其值从大到小排序,每次取出对应的 \operatorname{sa}_i\operatorname{sa}_{i-1},利用并查集合并,同时维护每个连通块大小,以及大小为 k 的连通块数量 c_k

从大到小枚举值到 i 时,把对应连通块维护好之后,此时某个连通块的大小 k 就说明这个长度为 i 的子串是 k 个后缀的公共前缀,也即该子串恰好出现了 k 次。如果不理解,就想想根据 \operatorname{height} 数组的性质: \operatorname{sa}_i\sim \operatorname{sa}_j 的 LCP 为 \min _{k=i+1}^j \operatorname{height}_k,或者用后缀树理解。

然后枚举 i 的倍数 j 作为出现次数,j\times c_j 就是对答案的贡献。

时间复杂度 O(n\log n),代码如下:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const LL mod = 1e9 + 7;
const int N = 1000005;

int sa[N];
int rk[N << 1], tk[N << 1];
int height[N];
int id[N], cnt[N];
void SA(char *s, int n, int m) {
    for (int i = 1; i <= n; i++) sa[i] = i, rk[i] = s[i];

    for (int i = 0; i <= m; i++) cnt[i] = 0;
    for (int i = 1; i <= n; i++) ++cnt[rk[i] = s[i]];
    for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
    for (int i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i;

    for (int w = 1; w < n; w *= 2) {
        int p = 0;
        for (int i = n - w + 1; i <= n; i++) id[++p] = i;
        for (int i = 1; i <= n; i++) {
            if (sa[i] > w) id[++p] = sa[i] - w;
        }
        for (int i = 1; i <= m; i++) cnt[i] = 0;
        for (int i = 1; i <= n; i++) ++cnt[rk[i]];
        for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; i--) sa[cnt[rk[id[i]]]--] = id[i];

        m = 0;
        for (int i = 1; i <= n; i++) {
            if (rk[sa[i]] == rk[sa[i - 1]] && rk[sa[i] + w] == rk[sa[i - 1] + w]) {
                tk[sa[i]] = m;
            } else {
                tk[sa[i]] = ++m;
            }
        }
        for (int i = 1; i <= n; i++) rk[i] = tk[i];
    }

    for (int i = 1, k = 0; i <= n; i++) {
        if (rk[i] == 0) continue;
        if (k) --k;
        while (s[i + k] == s[sa[rk[i] - 1] + k]) ++k;
        height[rk[i]] = k;
    }
}
char s[N];
vector<int> a[N];
int sz[N], p[N], c[N];
int find(int x) {
    return p[x] == x ? x : p[x] = find(p[x]);
}
int main() {
    int n;
    scanf("%d%s", &n, s + 1);
    SA(s, n, 128);
    for (int i = 1; i <= n; i++) {
        p[i] = i, sz[i] = 1;
        if (i > 1)
            a[height[i]].push_back(i);
    }
    c[1] = n;
    LL ans = 0;
    for (int i = n; i >= 1; i--) {
        for (auto x : a[i]) {
            int y = find(x - 1);
            x = find(x);
            if (x != y) {
                c[sz[x]]--;
                c[sz[y]]--;
                p[x] = y;
                sz[y] += sz[x];
                c[sz[y]]++;
            }
        }
        for (int j = i; j <= n; j += i) {
            ans += (LL)j * c[j];
        }
    }
    printf("%lld\n", ans);
    return 0;
}