P7888 「MCOI-06」Distinct Subsequences

· · 题解

总觉得就我一个人没有把问题转换成贡献

dp[i] 为前 i 位的答案。

为了方便理解,我们把每一个 子序列 看作一个集合,(为了方便之后表述,我们定义这个集合的 集合名 为这个子序列)。

集合里的元素是它 本质不同的非空子序列 ,则这个集合的价值就是它的大小,dp[i] 就是 s_{[1,i]} 的所有子序列的 集合大小 之和。

对于 i + 1,设第 i + 1 位上的字符为 p,首先我们可以什么都不做,那么现在的答案依然是 dp[i]

其次,对于 s_{[1,i]} 的任意子序列,可以在其后添加字符 p,构成了若干个新的集合,这些集合的大小之和现在是 dp[i](因为是复制过来的,元素还是那些元素,只是集合名多了个 p)。

容易想到,对于上一段所述的每一个集合,可以在其中每一个元素之后添加字符 p,这样就又得到了 dp[i] 个新的元素。现在总价值有 3 * dp[i]

但是有两个问题需要考虑。

首先,对于一个集合,如果它的集合名第一次加入 p,那么这个集合会多出一个单独的元素 p。这相当于在一个空的子序列之后拼接一个 p,而这是我们需要额外考虑的(因为集合里原先只有 非空子序列)。

假如在 [1, i]p 字符出现了 cnt_p 次,那么一共有 2^{i - cnt_p} 个集合满足这样的条件(对于每一个不是 p 的字符,选或者不选),也就是说说 dp[i + 1] 要额外加上 2^{i-cnt_p}

另外,在元素后添加 p 之后可能会出现重复的元素。比如说原先集合中存在 sa,s,a,倘若添加字符 a,那么就会变成 sa, s, a, saa, sa, aa ,其中 sa 是重复元素。容易发现这样的重复只发生在原先结尾就为 p 的元素上。

我们“退一步”想,假设这个发生冲突的集合是 s_{[1,j]},我们先假设 s_jp,那么每一个结尾为 p 的元素都可以认为是在 j 时刻被加入的。

如果在 j 时刻之前就已经有结尾为 p 的元素,假设为 s1,令子序列 s2 = s1 - p(表示去掉结尾的 p 虽然不清楚有没有这种运算)。

容易知道,要么这个子序列为空,要么这个子序列也同样存在于这个集合中。

那么我们可以认为在 j 时刻在 s2 后拼接上了 p,替换了原来的 s1

(类似于,“此位置已存在名为 s1 且扩展名为 p 的项目。您要用正在移动的项目来替换它吗?“)

也就是说,发生冲突的元素个数就等于 dp[j - 1]

对于 j 之后,不以 p 结尾的集合,由于我们会复制元素,这之中的每个集合也都会发生 dp[j - 1] 次冲突,而这样的集合共有 2^{i - j + 1 - cnt_{j, i, p}} 个。

cnt_{j, i, p} 表示从 ji 有多少个 p。)

综上,dp[i + 1] = 3dp[i] + 2^{i - cnt_{1,i,p}} - \sum_{j \le i, s_j = p}dp[j - 1]*2^{i - j + 1 - cnt_{j,i,p}}

后面那个和式,可以通过对每个字符分别记录一个 del[p] 来维护,在每一位上都进行一次 O(|\Sigma|) 的更新,令 del[s_i] += dp[i],并且对于所有 p \ne s_idel[p] *= 2。总复杂度 O(n|\Sigma|)

#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
#include<vector>
using namespace std;
const long long mod = 1000000007;
inline int read() {
    register int x = 0, f = 1; register char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-')f = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){x = x * 10 + ch - 48; ch = getchar();}
    return x * f;
}
template<class T>inline T mmin(register T x, register T y) {
    return x > y ? y : x;
}
template<class T>inline T mmax(register T x, register T y) {
    return x > y ? x : y;
}
int n;
long long dp[1000007];
int lst[37], cnt[37];
long long del[37];
char ch[1000007];
inline long long power(long long a, long long b) {
    long long ret = 1ll;
    for(; b; b >>= 1) {
        if(b & 1) ret = ret * a % mod;
        a = a * a % mod;
    }
    return ret;
}
signed main() {
    scanf("%s", ch + 1);
    n = strlen(ch + 1);
    dp[1] = 1; cnt[ch[1] - 'a'] = 1;
    for(int i = 2; i <= n; ++i) {
        int x = ch[i] - 'a';
        dp[i] = dp[i - 1] * 3 % mod;
        dp[i] = (dp[i] - del[x] % mod + mod) % mod;
        dp[i] = (dp[i] + power(2, i - cnt[x] - 1)) % mod;
        //这里的快速幂是可以避免的,但是这样写方便点(反正不会T
        for(int j = 0; j < 26; ++j) if(j != x) del[j] = del[j] * 2 % mod;
        del[x] = (del[x] + dp[i - 1]) % mod;
        ++cnt[x];
    }
    printf("%lld\n", dp[n]);
    return 0;
}