题解 AT693 【文字列】

· · 题解

Description

构造一个小写字母构成的字符串,使小写字母 i 出现了 freq_i 次,且相同字母不能相邻,求方案数并对 10^9 + 7 取模。

### Solution 计数 dp,先找出所有 $freq_i > 0$ 的小写字母,一个个插入字符串中。因为答案与插入顺序无关,所以我们按 `aabbbccddddeeefgg` 这样的顺序插入。在插入的过程中,难免会出现不合法的情况。举个栗子,如 `*a a*c*b*d d d*b b*`,不合法的位置有 $4$ 个,即为空格。合法位置有 $6$ 个,即为 `*`。可以发现三个性质: 1. $S$ 的合法位置数与不和法位置数之和为 $|S| + 1
  1. 没有两个合法位置或不合法位置是相邻的。
  2. 每一个合法位置可以插入任意个字母,也可以不插。

可以想到状态,令 f_{i,j} 为插入了所有的小写字母 i 后,有 j 个不合法位置的方案数。赋初值,将第一段 freq_i > 0 的小写字母全部插入后的 f_{i,j} = 1。如 bbbccddddeeefggbbb 是第一段,于是 f_{b - a = 1, 2} = 1

枚举 i,和 f_{i-1,k} 中的 k,意味着 f_{i,j} 会从 f_{i-1,k} 转移。考虑将所有的小写字母 i 插入到字符串中,无非是插入到合法的位置或不合法的位置。所以再枚举一个 goodbad,如字面意思,good 为插入到合法位置的段数bad 为插入到不合法位置的段数。在枚举的过程中,同时维护一个在加入字母前的字符串长度 sum,用于确定枚举范围。

考虑一个问题:将所有小写字母 i 加入字符串,有 good 段字母插入了合法位置,有 bad 段字母插入了不合法位置。那么新产生了多少不合法的位置,又减少了多少不合法的位置?

减少的就是 good 个。因为插入的字母分成了 good + bad 段,那么新产生了 good + bad - 1 个间隔。如果 good+bad=1,那么成为不合法位置的位置一共有 freq_i - 1 个,所以增加了 freq_i - 1 - good - bad + 1 = freq_i - good - bad 个不合法位置。

所以,j = k - bad + freq_i - good - bad

转移方程如下,其中 sum + 1 - k 为根据性质一求出的合法位置个数,n \choose m 为组合数可以预处理。可以发现,转移方程是根据乘法原理将每一部分相乘,其中的 {freq_i - good - bad} \choose {f_i - 1} 是什么呢?因为我们是把一段段相同的字母插入合法的位置或不合法的位置,不仅要乘上几段字母插入到哪几个位置的方案数,还要乘上所有的字母 i 是如何分成了几段字母的方案数。把一段字母看成一个盒子,每个字母看成一个球,一共 good + bad 个盒子,freq_i 个球,问题转换成求每个盒子至少一个球的方案数。套上插板法公式即可。

f_{i,j} = f_{i,j} + f_{i - 1,k} \times {{bad}\choose{k}} \times {{good}\choose{sum + 1 - k}} \times {{good + bad -1} \choose{freq_i - 1}} ans = f_{n-1,0}

时间复杂度 O(n^3),其中 n = \sum_{i=0}^{25} freq_i

Code

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 500 + 5, p = 1e9 + 7;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
int f[26][N], C[N][N], freq[N], sum, fir, n;
signed main() {
    for (int i = 0; i < 26; i++) {
        freq[n] = read();
        if (freq[n]) n++;
    }
    for (int i = 0; i <= 300; i++) {
        C[i][0] = C[i][i] = 1;
        for (int j = 1; j < i; j++) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % p;
    }
    f[0][freq[0] - 1] = 1, sum = freq[0];
    for (int i = 1; i < n; i++) {
        for (int k = 0; k <= sum; k++) 
            for (int bad = 0; bad <= min(freq[i], k); bad++)
                for (int good = 0; good + bad <= freq[i] && good <= sum + 1 - k; good++) {
                    int j = k - bad + freq[i] - good - bad;
                    f[i][j] = (f[i][j] + f[i - 1][k] * C[k][bad] % p * C[sum + 1 - k][good] % p 
                    * C[freq[i] - 1][good + bad - 1] % p) % p;
                }
        sum += freq[i]; 
    }   
    printf("%lld\n", f[n - 1][0]);
    return 0;
}