题解:P10581 [蓝桥杯 2024 国 A] 重复的串

· · 题解

思路

看到字符串匹配,方案数关键字,自然想到 dp 与 kmp 算法的结合,看到 n 的数量比较大,不用慌,先把朴素的方程想出来再优化。

一般地可以想到方程 dp[i][j][k] 为字符串序列长度为 i 并且当前匹配到模式串的长度为 j ,总匹配次数恰好为 k 次的方案数。转移的方式就是,我们枚举每个 i, j, k 再枚举当前状态下填的字母,通过 kmp 进行匹配,计算出下一个可转移的状态。代码:

dp[0][0][0] = 1;
for (int i = 0; i < n; i++)
    for (int k = 0; k <= 2; k++)
        for (int j = 0; j < m; j++) //枚举每个状态
            for (char ch = 'a'; ch <= 'z'; ch++) { //枚举当前状态填的字母
                int st = j;
        // nex 数组预处理过程省略
                while (st && str[st + 1] != ch) st = nex[st];
                if (str[st + 1] == ch) st++;
                if (st == m) { //如果完全匹配成功可转移到 k + 1
                    dp[i + 1][nex[st]][k + 1] = (dp[i + 1][nex[st]][k + 1] + dp[i][j][k]) % MOD;
                }
                else dp[i + 1][st][k] = (dp[i + 1][st][k] + dp[i][j][k]) % MOD;
            }
ans = 所有j的dp[n][j][2]之和;

上面的方程可以通过一部分的数据,但我们可以做的更好。

通过观察发现,对于每个 j 填的下一个字母,它可以转移的状态是固定的!比如样例中的 abaj = 0 的情况下,不管 i, k 是多少,只要填字母 a 就最终总是会到达状态 j = 1;只要填字母 b 总是会到状态 j = 0。于是,我们就考虑使用矩阵加速递推,关于矩阵加速递推可以参考 OI-WIKI。

代码

#include <iostream>
#include <string.h>
using namespace std;
#define MAX_N 105
#define ll long long
const ll MOD = 998244353LL;
char str[MAX_N];
int nex[MAX_N] = { 0 };
ll arr[MAX_N][MAX_N] = { 0 };
ll temp[MAX_N][MAX_N] = { 0 };
ll ret[MAX_N][MAX_N] = { 0 };
//矩阵乘法
void multi(ll a[MAX_N][MAX_N], ll b[MAX_N][MAX_N], int n, int m, int k) {
    memset(temp, 0, sizeof temp);
    for (int i = 0; i < n; i++)
        for (int c = 0; c < k; c++)
            for (int r = 0; r < m; r++)
                temp[i][c] = (temp[i][c] + a[i][r] * b[r][c] % MOD) % MOD;
    memcpy(a, temp, sizeof temp);
}
//矩阵快速幂
void quick_mi(ll b, int n) {
    while (b) {
        if (b % 2) multi(ret, arr, n, n, n);
        multi(arr, arr, n, n, n);
        b /= 2;
    }
}
int n, m;
void get_nex() {
    nex[0] = nex[1] = 0;
    for (int i = 2, j; i <= m; i++) {
        j = i - 1;
        while (j && str[nex[j] + 1] != str[i]) j = nex[j];
        if (j) nex[i] = nex[j] + 1;
        else nex[i] = 0;
    }
}
//给每个状态分配一个下标
int get_ind(int j, int k) {
    return j + k * (m + 1);
}
ll matrix[MAX_N][MAX_N] = { 0 };
int main() {
    scanf("%s", str + 1);
    scanf("%d", &n);
    m = strlen(str + 1);
    get_nex();
    //初始化单位矩阵
    for (int x = 0; x < (m + 1) * 3; x++)
        ret[x][x] = 1;
    for (int k = 0; k <= 2; k++)
        for (int j = 0; j < m; j++)
            for (char ch = 'a'; ch <= 'z'; ch++) {
                int st = j;
                while (st && str[st + 1] != ch) st = nex[st];
                if (str[st + 1] == ch) st++;
                if (st == m) {
                    if (k != 2) arr[get_ind(j, k)][get_ind(nex[st], k + 1)]++;
                }
                else arr[get_ind(j, k)][get_ind(st, k)]++;
            }
    matrix[0][get_ind(0, 0)] = 1;
    quick_mi(n, (m + 1) * 3);
    multi(matrix, ret, 1, (m + 1) * 3, (m + 1) * 3);
    ll ans = 0;
    for (int j = 0; j <= m; j++)
        ans = (ans + matrix[0][get_ind(j, 2)]) % MOD;
    printf("%lld", ans);
    return 0;
}