题解:P10375 [AHOI2024 初中组] 计数

· · 题解

分析

通过对题目的分析,不难想到,如果一个串的首尾完全一样,那么一定可以消除。形如:

a...a         a....b...c...a

当一个串首尾不同,但是可以进行多组分块,也可以完成消除。形如:

a...ab....bc....cz...z

对于一个不能完成消除的串来说,我们可以选择增补数字使得它变成可以消除的串。形如:

a...ab...bc...

只需在末尾增补 abc 就可以使得上述情况变成可以消除的串。

所以,我们设 f_{i,j,k} 表示前 i 个元素,分成 j 组,可以(k = 1)或不可以(k = 0)完成消除的方案数。

其中,我们认为形如 a...ab 这样的串为 2 组,虽然 b 不能完成匹配。

边界

考虑 dp 边界为只有 1 个元素,分组为 1,一定是不能完成消除的情况,该情况共有 m 种,分别是:

1, 2, ..., m

f_{1,1,0} = m

答案

考虑答案为使用 n 个元素,不确定被分为几组,但是可以完成消除的方案数。

即输出为 ans = \sum \limits_{i = 1}^{n} f_{n,i,1}

转移方程

考虑转移方程:

  1. 一组可以完成消除的串 a...a,增补上先前没有出现的元素,可以使得组数多 1,变为 a...ab,共计 m - j 种情况。

  2. 不管原本是否是可以消除的串,在末尾增补上分组中含有的元素,可以使得它变为一个可以消除的串。a...ab... 在末尾新增 aba...ab...bc...c 在末尾新增 abc 可以使得串变为可以消除,共计 j 种情况。

  3. 原本不能完成消除的串,增补上分组中未出现的元素,仍然使得它无法完成消除。a...ab...,增补上 c,共计 m - j 种情况。

所以我们可以写出转移方程为:

\begin{cases} f_{i,j+1,0}=f_{i-1,j,1}\times (m-j) \\ f_{i,j,1}=(f_{i-1,j,1} + f_{i-1,j,0}) \times j\\ f_{i,j,0}=f_{i-1,j,0}\times (m-j) \end{cases}

代码块

其中,使用 i 个元素时,最多被分为 i 组,所以第二层循环应该是 j \le i。我们可以写出代码:

#include <bits/stdc++.h>
using namespace std;
int n, m, mod = 1e9 + 7;
long long f[3001][3001][2], ans;
int main() {
    cin >> n >> m;
    f[1][1][0] = m;
    for (int i = 2; i <= n; i++) {
        for (int j = 1; j <= i; j++) {
            f[i][j + 1][0] = (f[i][j + 1][0] + f[i - 1][j][1] * (m - j)) % mod;
            f[i][j][1] = ((f[i][j][1] + f[i - 1][j][0] + f[i - 1][j][1]) % mod * j) % mod;
            f[i][j][0] = (f[i][j][0] + f[i - 1][j][0] * (m - j)) % mod;
        } 
    }
    for (int i = 1; i <= n; i++) ans = (ans + f[n][i][1]) % mod;
    cout << ans;
    return 0;
}

时间复杂度:O(n^2)

空间复杂度:O(n^2)

优化

我们发现 f_i 仅与 f_{i-1} 有关,可以将原数组进行滚动操作,又发现 f_{j+1}f_j 有关,所以顺序循环会导致当前这一层的值发生重复计算,所以对第二层循环可以采取逆序操作(例如0-1背包的转移方程优化)。

#include <bits/stdc++.h>
using namespace std;
int n, m, mod = 1e9 + 7;
long long ans, f[3001][2];
int main() {
    cin >> n >> m;
    f[1][0] = m;
    for (int i = 2; i <= n; i++) {
        for (int j = i; j > 0; j--) {
            f[j + 1][0] = (f[j + 1][0] + f[j][1] * (m - j)) % mod;
            f[j][1] = ((f[j][1] + f[j][0]) * j) % mod;
            f[j][0] = (f[j][0] * (m - j)) % mod;
        } 
    }
    for (int i = 1; i <= n; i++) ans = (ans + f[i][1]) % mod;
    cout << ans;
    return 0;
}

时间复杂度:O(n^2)

空间复杂度:O(n)