题解 P5417 【[CTSC2016]萨菲克斯·阿瑞】
鉴于这题已经过了三年且网上没有靠谱题解,以及笔者写本文时所有 luogu AC 代码均是本人的代码,现在这里将简单介绍一下思路,起些抛砖引玉的作用。
如果要详细的题解 (比如细节等),请左转 此处 以获得更好(cha)的阅读体验~
我们分 7 步来完成这道题。
---- Step 1 ----
首先考虑对于一个长度为
我们在
于是,
---- Step 2 ----
接下来我们就能推出,对于一个确定的后缀数组,怎样的串是满足条件的。
由 Step 1 的结论,所有
举个例子:后缀数组 5 3 1 4 2 形成的不等式链即为
---- Step 3 ----
我们先从最简单的情形着手,
为了防止计算重复,我们对于一个后缀数组
因此一个后缀数组的秩等于它在 Step 2 的不等式链中 "
我们需要统计字符集大小为
设字符串中有 a,b。因此将它们进行带重复元素的排列,可知共有
我们要从中去掉秩为 a,b 构成的字符串中,秩为
---- Step 4 ----
经历完
还是先统计秩为 a,b 以及 c。
由可重排列,由这些元素构成的秩不超过
当然,这些是秩不超过
回到 Part 2 中的不等式链。首先,由于这些 SA 的秩不超过
考察所有
而事实上,它的秩可以是
然而
不妨设 "
也就是说,此时,这个后缀数组即为满足
因此,每个 "由
同理,"
可以看出,这是一个容斥的过程。因此,我们还需要把秩为
---- Step 5 ----
然鹅题目不是让你统计秩为
那我们刚才搞那么多到底是为了啥?无非就是构造一个一一对应,将一个排列对应到一个串上去,使得计数不重不漏啊!
因此,对一个秩为
具体地,我们按字典序从小到大枚举每个字符出现了几次,然后算出有多少个 "满秩" 的后缀数组。(这里 "满秩" 的意思是,这个串的字符集大小恰好等于该后缀数组的秩,不出现字符集冗余的情况)
但是这里还有一个比较关键的问题。有的秩比较小的后缀数组,它的来源是一个较大的字符集,而它所对应的 "满秩" 的字符串是不符合题目要求的。
举个例子,你有
因此我们需要一点计数技巧,在这里可以采用 "合并" 这种操作。
考虑 (反正也是有借无还嘛),从而完成 "合并" 的过程。
但是,能借后面的串用的前提是,
当然,不仅是
---- Step 6 ----
最后就只剩下实现了,至于如何高效地完成这个容斥的过程,那就考虑用 DP 来求解容斥系数啦。
将
首先,分子上的
(下面就是 DP 状态定义啦)
用
要注意的一点是,注意,这个 DP 是基于贪心的,因此每个字符一定是能用则用,因此不使用的字符一定是一段后缀,即一旦一个字符未使用,就预示着这个串的 "终结"。
先扔掉
转移分三种情况:
-
第
i 个字符使用若干 (正整数) 个,然后将这一段切掉 (也是最正常的情况)。
设这种字符使用了l 个,则转移为f_{i + 1, j + l, 0} \uparrow f_{i, j, k} \cdot \frac 1 {(k + l)!} -
第
i 个字符准备和下一个字符进行容斥型合并,即它们在\left( 3 \right) 式中使用加号产生贡献。
由于是容斥型合并,每多一个 "+ " 号就要产生-1 的系数。故转移方程为f_{i, j + l, k + l} \uparrow - f_{i, j, k} -
第
i 个字符准备和下一个字符进行正常合并 (即统计秩比较小的后缀数组)。由 Step 5,此时要求第i 个字符必须用满。
因此转移系数还是+ 1 ,方程为
以上 a += b (in C++)。
至于统计答案,相当于统计这个串在哪个字符处 "终结",即答案等于
时间复杂度
---- Step 7 ----
最后一步相信大家都明白,对这个 DP 进行优化。
不难发现,当固定
因此,可以使用前缀和优化来解决,不过要注意一下求和的上下限。
于是单点转移就被优化到了
代码:
#include <bits/stdc++.h>
#define jl j
#define kl k
typedef long long ll;
const int N = 540, mod = 1000000007;
int n, m;
int c[N], fact[N], finv[N];
int f[N][N], F[N][N];
inline void add(int &x, const int y) {x = (x + y >= mod ? x + y - mod : x + y);}
ll PowerMod(ll a, int n, ll c = 1) {for (; n; n >>= 1, a = a * a % mod) if (n & 1) c = c * a % mod; return c;}
int main() {
int i, j, k, l; ll ans = 0;
scanf("%d%d", &n, &m);
for (i = 1; i <= m; ++i) if (scanf("%d", c + i), !c[i]) --i, --m;
for (*fact = i = 1; i <= n; ++i) fact[i] = (ll)fact[i - 1] * i % mod;
finv[n] = PowerMod(fact[n], mod - 2);
for (i = n; i; --i) finv[i - 1] = (ll)finv[i] * i % mod;
for (**f = i = 1; i <= m; ++i) {
for (j = 0; j <= n; ++j)
for (k = 0; k <= j; ++k)
add(F[j + 1][k + 1] = F[j][k], f[j][k] + (f[j][k] >> 31 & mod)), f[j][k] = 0;
for (jl = 1; jl <= n; ++jl)
for (kl = 1; kl <= jl; ++kl) {
l = std::min(c[i], kl);
f[jl][0] = (f[jl][0] + (ll)finv[kl] * (F[jl][kl] - F[jl - l][kl - l])) % mod;
if (jl != n) {
l = std::min(c[i] - 1, kl),
f[jl][kl] = ((ll)f[jl][kl] - F[jl][kl] + F[jl - l][kl - l]) % mod;
}
}
ans += f[n][0];
}
ans = ans % mod * fact[n] % mod;
printf("%lld\n", ans + (ans >> 63 & mod));
return 0;
}