题解:P16448 [XJTUPC 2026] Triple Mirror: The Harmony of Repetition

· · 题解

考虑拆贡献。

对每个点 i,计算其作为翻转字符串的那段的开头的方案数会舒服一些。若子串长度为 3k,则对于每个 0\le j\le k-1,我们需要满足 s_{i-2-2j}=s_{i-1-2j}=s_{i+j}。由于若长度为 3k 的子串是合法的,则长度为 3(k-1) 也必定合法,所以只需要计算合法的最大的 k,也就是满足上述式子的最长前缀 j

将条件拆分成 s_{i-2-2j}=s_{i-1-2j}s_{i-2-2j}=s_{i+j},分别求出满足条件的最大 j 然后取最小值即可。对于第一个条件,容易用 dp 求出;对于第二个条件,将奇偶位置上的字符分别取出来并取反,求 LCP 即可。

::::info[Code]

#include <bits/stdc++.h>
using namespace std;

const int N = 600000 + 10;

struct SA {
    int a[N], b[N], sa[N], h[N], rk[N], t[N];
    string s;
    int f[N][20];
    int n, m;

    void getsa() {
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));
        memset(t, 0, sizeof(t));
        memset(sa, 0, sizeof(sa));
        memset(rk, 0, sizeof(rk));
        for (int i = 1; i <= n; i++) {
            a[i] = s[i];
            ++t[a[i]];
        }
        for (int i = 2; i <= 128; i++)
            t[i] += t[i - 1];
        for (int i = n; i >= 1; i--)
            sa[t[a[i]]--] = i;
        int now = 128;
        for (int k = 1; k <= n; k *= 2) {
            int cnt = 0;
            for (int i = n - k + 1; i <= n; i++)
                b[++cnt] = i;
            for (int i = 1; i <= n; i++)
                if (sa[i] > k)
                    b[++cnt] = sa[i] - k;
            memset(t, 0, sizeof(t));
            for (int i = 1; i <= n; i++)
                t[a[i]]++;
            for (int i = 2; i <= now; i++)
                t[i] += t[i - 1];
            for (int i = n; i >= 1; i--)
                sa[t[a[b[i]]]--] = b[i], b[i] = 0;
            swap(a, b);
            int tot = 1;
            a[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                if (b[sa[i]] == b[sa[i - 1]] && b[sa[i] + k] == b[sa[i - 1] + k])
                    a[sa[i]] = tot;
                else
                    a[sa[i]] = ++tot;
            }
            if (tot == n)
                break;
            now = tot;
        }
    }

    void gethi() {
        memset(rk, 0, sizeof(rk));
        memset(h, 0, sizeof(h));
        for (int i = 1; i <= n; i++)
            rk[sa[i]] = i;
        int now = 0;
        for (int i = 1; i <= n; i++) {
            if (rk[i] == 1)
                continue;
            if (now)
                now--;
            int j = sa[rk[i] - 1];
            while (i + now <= n && j + now <= n && s[i + now] == s[j + now])
                now++;
            h[rk[i]] = now;
        }
    }
    int ask(int l, int r) {
        if (l == r)
            return n - l + 1;
        l = rk[l], r = rk[r];
        if (l > r)
            swap(l, r);
        l++;
        int k = log2(r - l + 1);
        return min(f[l][k], f[r - (1 << k) + 1][k]);
    }
    void st() {
        memset(f, 0, sizeof(f));
        for (int i = 1; i <= n; i++)
            f[i][0] = h[i];
        for (int i = 1; i <= 19; i++)
            for (int j = 1; j + (1 << i) - 1 <= n; j++)
                f[j][i] = min(f[j][i - 1], f[j + (1 << (i - 1))][i - 1]);
    }

    void bu(string ss) {
        s = " " + ss;
        n = ss.size();
        getsa();
        gethi();
        st();
    }
} sa;
int f[200005];

int main() {
    string s;
    cin >> s;
    int n = s.size();
    string a, b;
    for (int i = 0; i < n; i += 2)
        a += s[i];
    for (int i = 1; i < n; i += 2)
        b += s[i];
    reverse(a.begin(), a.end());
    reverse(b.begin(), b.end());
    int l1 = a.size(), l2 = b.size();
    string aa = s;
    aa += '#';
    int n1 = aa.size() + 1;
    aa += a;
    aa += '$';
    int n2 = aa.size() + 1;
    aa += b;
    sa.bu(aa);
    for (int i = 2; i <= n; i++)
        if (s[i - 2] == s[i - 1]) {
            f[i] = 1;
            if (i >= 4)
                f[i] += f[i - 2];
        }
    long long ans = 0;
    for (int i = 2; i < n; i++) {
        if (s[i] == 0)
            continue;
        int j = i & 1;
        int len, qi;
        if (j == 0)
            len = l1, qi = n1;
        else
            len = l2, qi = n2;
        int id = (i - j - 2) / 2;
        int q = len - id - 1;
        int lc = sa.ask(i + 1, qi + q);
        ans += min(lc, min(n - i, f[i]));
    }
    cout << ans;
}

::::