题解:P12124 [蓝桥杯 2024 省 B 第二场] 前缀总分

· · 题解

下文中记 C = 26,表示字符种类的数量。

暴力解法 O(Cn^5)

枚举将第 i 个字符串的第 j 个字符改为 c 的所有方案,时间复杂度 O(Cn^2),修改并计算总分,O(n^3)

暴力优化 O(Cn^3\log n)

我们可以使用字符串哈希来优化判断两个字符串是否相等。

另外,可以用二分来优化求两个字符串的最大前缀。

枚举所有方案的时间复杂度为 O(Cn^2),处理修改以及计算总分的复杂度为 O(n\log n)

再优化 O(Cn^3)

首先,我们依旧使用上述暴力解法中的枚举方式——所有将第 i 个字符串的第 j 个字符改为 k,时间复杂度 O(Cn^2)

接下来我们考虑,如果用不大于 O(n) 的时间去完成计算一个枚举的分数。

将第 i 个字符串的第 j 个字符改为 k 时,所影响答案的只有 P(s_1, s_i), P(s_2, s_i), P(s_3, s_i), \dots, P(s_n, s_i)

所以我们可以计算出未修改时的总得分的 tot,计算出未修改时第 i 个字符串对答案的贡献 g[i]。设修改之后第 i 个字符串对答案的贡献为 res,那么修改之后的答案即为 tot - g[i] + res

那么接下来,我们要尝试处理计算,将第 i 个字符串的第 j 个字符改为 k 之后,第 i 个字符串对答案的贡献。

那么显而易见,我们需要计算修改之后的第 i 字符串与剩下 n-1 个字符串的最大前缀。

设其中一个字符串为 s_u,计算修改之后的 s_i 与修改之前,只有第 j 个字符被改变,j 左侧的字符,以及右侧的字符均为改变。

那么我们可以尝试比较修改前的 s_is_u0 开始的最大前缀长度 left

上述分析中,我们多次需要用到第 i 个字符串与第 j 个字符串从 k 开始的最大前缀。

考虑动态规划:f[i][j][k] 表示第 i 个字符串与第 j 个字符串从 k 开始的最大前缀长度。

考虑动态转移:

由于计算 f[i][j][k] 时,需要用到 f[i][j][k + 1],故预处理 f 数组时需要倒序处理。

暴力优化 O(Cn^3\log n)

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>

using namespace std;

const int N = 2e2 + 10, P = 131;

typedef unsigned long long ULL;

int n;
string str[N];
ULL f[N][N], p[N];
int g[N];
int tot;

ULL query(int u, int l, int r)
{
    return f[u][r] - f[u][l - 1] * p[r - l + 1];
}

int calc(int u, bool flag)
{
    int res = 0;
    for (int i = 1; i <= n; ++ i )
        if (i != u)
        {
            int l = 1, r = min(str[u].size() - 1, str[i].size() - 1);
            while (l < r)
            {
                int mid = l + r + 1 >> 1;
                if (query(i, 1, mid) == query(u, 1, mid))
                    l = mid;
                else
                    r = mid - 1;
            }
            if (query(i, 1, l) == query(u, 1, l))
                res += l;
        }

    if (flag)
    {
        g[u] = res;
        tot += res;
    }

    return res;
}

int modify(int u, int k, int c)
{
    char t = str[u][k];
    str[u][k] = 'a' + c;

    for (int i = 1; i < str[u].size(); ++ i )
        f[u][i] = f[u][i - 1] * P + str[u][i];

    int res = tot - g[u] * 2 + calc(u, false) * 2;

    str[u][k] = t;

    for (int i = 1; i < str[u].size(); ++ i )
        f[u][i] = f[u][i - 1] * P + str[u][i];

    return res / 2;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++ i )
    {
        cin >> str[i];
        str[i] = ' ' + str[i];
    }

    p[0] = 1;
    for (int i = 1; i < N; ++ i )
        p[i] = p[i - 1] * P;

    for (int i = 1; i <= n; ++ i )
        for (int j = 1; j < str[i].size(); ++ j )
            f[i][j] = f[i][j - 1] * P + str[i][j];

    for (int i = 1; i <= n; ++ i )
        calc(i, true);

    int res = 0;
    for (int i = 1; i <= n; ++ i )
        for (int j = 1; j < str[i].size(); ++ j )
            for (int k = 0; k < 26; ++ k )
                res = max(res, modify(i, j, k));

    cout << res << endl;

    return 0;
}

再优化 O(Cn^3)

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>

using namespace std;

const int N = 2e2 + 10;

int n;
string str[N];
int g[N];
int tot;
int f[N][N][N];     //  [i, j, k] 第i个字符串和 第j个字符串 k个字符起最大连续数量

void init()
{
    for (int i = 1; i <= n; ++ i )
        for (int j = i + 1; j <= n; ++ j )
        {
            int mn = min(str[i].size(), str[j].size());
            for (int k = mn - 1; k >= 0; -- k )
                if (str[i][k] == str[j][k])
                    f[i][j][k] = f[j][i][k] = f[i][j][k + 1] + 1;
        }

    for (int i = 1; i <= n; ++ i )
    {
        for (int j = 1; j <= n; ++ j )
            g[i] += f[i][j][0];
        tot += g[i];
    }

    tot /= 2;
}

int modify(int u, int k, int c)
{
    int res = 0;
    for (int i = 1; i <= n; ++ i )
        if (i != u)
        {
            int left = min(f[u][i][0], k);
            res += left;

            if (left == k && str[i].size() > k && str[i][k] - 'a' == c)
            {
                res ++;
                res += f[u][i][k + 1];
            }
        }

    return tot - g[u] + res;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++ i )
        cin >> str[i];

    init();

    int res = 0;
    for (int i = 1; i <= n; ++ i )
        for (int j = 0; j < str[i].size(); ++ j )
            for (int k = 0; k < 26; ++ k )
                res = max(res, modify(i, j, k));

    cout << res << endl;

    return 0;
}