题解:CF1183H Subsequences (hard version)

· · 题解

题目传送门 CF

题目大意

给出一个长度为 n 的字符串 s,每次操作选出一个 s 的子序列 t 加入到集合 S,其中 t 不能和 S 中已有的元素重复。每次代价为 n - |t|。求出 k 次操作的最小总代价。若 s 没有 k 个子序列,输出 -1

题目分析

看到 n - |t|,可以想到贪心地选最长的子序列,接下来我们的任务就是求出 s 中每种长度的子序列有多少个。

直接考虑子序列很麻烦,考虑正难则反。先算出所有的子序列,一个很简单的 dp:dp_{i, j} = dp_{i - 1, j} + dp_{i - 1, j - 1}。其中 dp_{i, j} 表示 s 中前 i 个字符里面长度为 j 的子序列数量。

然后,我们来去掉重复的子序列。举个例子,在字符串 abbc 中,第一个 b 和第二个 b 重复了,所以 dp_{3, 2}dp_{3, 1} 都要减去个 1。也就是总共的子序列数量要减去包含上一个相同的字符的子序列数量即减去 dp_{pos_{s_i} - 1, j - 1}。其中,pos_{s_i} 表示上一个 s_i 的位置。

所以得出最终的 dp 方程:dp_{i, j} = dp_{i - 1, j} + dp_{i - 1, j - 1} - dp_{pos_{s_i} - 1, j - 1}

无解的情况就是,所有子序列数量小于 k

注意1 \le k \le 10 ^ {12},要开 long long!!!

code

#include <bits/stdc++.h>
#define ft first
#define sd second
#define endl '\n'
#define pb push_back
#define md make_pair
#define gc() getchar()
#define pc(ch) putchar(ch)
#define umap unordered_map
#define pque priority_queue
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 bint;
typedef pair<int, int> pii;
typedef pair<pii, int> pi1;
typedef pair<pii, pii> pi2;
const ll INF = 0x3f3f3f3f;
const db Pi = acos(-1.0);
inline ll read()
{
    ll res = 0, f = 1; char ch = gc();
    while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = gc();
    while (ch >= '0' && ch <= '9') res = (res << 1) + (res << 3) + (ch ^ 48), ch = gc();
    return res * f;
}
inline void write(ll x)
{
    if (x < 0) x = -x, pc('-');
    if (x > 9) write(x / 10);
    pc(x % 10 + '0');
}
inline void writech(ll x, char ch) { write(x), pc(ch); }
const int N = 1e2 + 5;
ll dp[N][N]; // dp[i][j] : 前 i 个字符里面长度为 j 的子序列数量 
int pos[26]; // pos[i] : 字符 i + 'a' 上一次出现的位置 
int main()
{
    int n = read(); ll k = read();
    string s; cin >> s; s = ' ' + s;
    dp[0][0] = 1;
    for (int i = 1; i <= n; i++)
    {
        dp[i][0] = 1;
        for (int j = 1; j <= i; j++)
        {
            dp[i][j] = dp[i - 1][j - 1] + dp[i - 1][j];
            if (pos[s[i] - 'a']) dp[i][j] -= dp[pos[s[i] - 'a'] - 1][j - 1];
        }
        pos[s[i] - 'a'] = i;
    }
    ll sum = 0;
    for (int i = 0; i <= n; i++)
    {
        sum += dp[n][i];
        if (sum >= k) break; // 防止爆 long long 
    } 
    if (sum < k) // 子序列数量不到 k 个 
    {
        puts("-1");
        return 0; 
    }
    ll ans = 0;
    for (int i = n; i >= 0; i--)
    {
        if (k >= dp[n][i])
        {
            k -= dp[n][i];
            ans += 1ll * dp[n][i] * (n - i);
        }
        else
        {
            ans += 1ll * k * (n - i);
            break;
        }
    }
    write(ans);
    return 0;
}