题解:P10581 [蓝桥杯 2024 国 A] 重复的串
思路
看到字符串匹配,方案数关键字,自然想到 dp 与 kmp 算法的结合,看到
一般地可以想到方程
dp[0][0][0] = 1;
for (int i = 0; i < n; i++)
for (int k = 0; k <= 2; k++)
for (int j = 0; j < m; j++) //枚举每个状态
for (char ch = 'a'; ch <= 'z'; ch++) { //枚举当前状态填的字母
int st = j;
// nex 数组预处理过程省略
while (st && str[st + 1] != ch) st = nex[st];
if (str[st + 1] == ch) st++;
if (st == m) { //如果完全匹配成功可转移到 k + 1
dp[i + 1][nex[st]][k + 1] = (dp[i + 1][nex[st]][k + 1] + dp[i][j][k]) % MOD;
}
else dp[i + 1][st][k] = (dp[i + 1][st][k] + dp[i][j][k]) % MOD;
}
ans = 所有j的dp[n][j][2]之和;
上面的方程可以通过一部分的数据,但我们可以做的更好。
通过观察发现,对于每个 aba 在
代码
#include <iostream>
#include <string.h>
using namespace std;
#define MAX_N 105
#define ll long long
const ll MOD = 998244353LL;
char str[MAX_N];
int nex[MAX_N] = { 0 };
ll arr[MAX_N][MAX_N] = { 0 };
ll temp[MAX_N][MAX_N] = { 0 };
ll ret[MAX_N][MAX_N] = { 0 };
//矩阵乘法
void multi(ll a[MAX_N][MAX_N], ll b[MAX_N][MAX_N], int n, int m, int k) {
memset(temp, 0, sizeof temp);
for (int i = 0; i < n; i++)
for (int c = 0; c < k; c++)
for (int r = 0; r < m; r++)
temp[i][c] = (temp[i][c] + a[i][r] * b[r][c] % MOD) % MOD;
memcpy(a, temp, sizeof temp);
}
//矩阵快速幂
void quick_mi(ll b, int n) {
while (b) {
if (b % 2) multi(ret, arr, n, n, n);
multi(arr, arr, n, n, n);
b /= 2;
}
}
int n, m;
void get_nex() {
nex[0] = nex[1] = 0;
for (int i = 2, j; i <= m; i++) {
j = i - 1;
while (j && str[nex[j] + 1] != str[i]) j = nex[j];
if (j) nex[i] = nex[j] + 1;
else nex[i] = 0;
}
}
//给每个状态分配一个下标
int get_ind(int j, int k) {
return j + k * (m + 1);
}
ll matrix[MAX_N][MAX_N] = { 0 };
int main() {
scanf("%s", str + 1);
scanf("%d", &n);
m = strlen(str + 1);
get_nex();
//初始化单位矩阵
for (int x = 0; x < (m + 1) * 3; x++)
ret[x][x] = 1;
for (int k = 0; k <= 2; k++)
for (int j = 0; j < m; j++)
for (char ch = 'a'; ch <= 'z'; ch++) {
int st = j;
while (st && str[st + 1] != ch) st = nex[st];
if (str[st + 1] == ch) st++;
if (st == m) {
if (k != 2) arr[get_ind(j, k)][get_ind(nex[st], k + 1)]++;
}
else arr[get_ind(j, k)][get_ind(st, k)]++;
}
matrix[0][get_ind(0, 0)] = 1;
quick_mi(n, (m + 1) * 3);
multi(matrix, ret, 1, (m + 1) * 3, (m + 1) * 3);
ll ans = 0;
for (int j = 0; j <= m; j++)
ans = (ans + matrix[0][get_ind(j, 2)]) % MOD;
printf("%lld", ans);
return 0;
}