题解 「DTOI-5」进行一个排的重 (Maximum Version)

· · 题解

下文设 ap 升序排序后的 q'q_0

  1. 求解 f(a)_{\max}

首先给出结论:答案为 q_0 的 LIS 长度再加上 n

为了证明这一点,首先我们需要证明如下结论:

证明:这里我们考虑从前往后依次调整每一个不满足条件的项。对于不产生贡献的 (p'_i, q'_i),找到离 i 最近的 0 \leq j < i - 1,使得 \displaystyle\max_{k = 1}^j p'_k < p'_i\displaystyle\max_{k = 1}^j q'_k < q'_i。下面讨论满足前一条件的情况。

这里考虑把 a'_i 插入到 a'_j 后面去,考虑一下贡献的变化量:

则单次调整带来的新贡献 \in [1, 2],满足不劣的条件。

于是,我们可以将任意方案调整为一个不劣且每对 p'_i, q'_i 至少有一个产生贡献的情况。

现在我们只需要考虑两个均产生贡献的情况,则:

利用该结论,又因为本题中 n 范围较小,直接暴力 dp 求 LIS 即可。时间复杂度为 O(n^2)

  1. 求解 f(a)_{\max} 的方案数

考虑利用一定存在一个 q_0 的 LIS,使得其顺序地作为每个最优 a' 的子序列的性质 dp。

g_i 表示 q_0 的前缀 [1, i] 的满足最后一个元素为 i 的 LIS 长度。特别地,我们钦定 g_{n + 1} = \displaystyle\max_{i = 1}^n g_i + 1

dp_{i} 表示从后往前考虑到了 [i, n + 1] 这个后缀,i 在我们钦定的某个 LIS 中的方案数。

初值:dp_{n + 1} = 1

转移:dp_i = \displaystyle\sum_{j > i, (q_0)_j > (q_0)_i, g_j = g_i + 1}^{n + 1} dp_j C_{x + y}^x,其中 x 表示满足 1 \leq k < i(q_0)_i < (q_0)_k < (q_0)_jk 的个数,y 表示满足 i < k < j(q_0)_k < (q_0)_ik 的个数。

答案:dp_0

暴力实现是 O(n^3) 的,直接上前缀和优化即可做到 O(n^2)

综上,时间复杂度为 O(n^2)

代码:

#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;

const int mod = 998244353;
int dp1[10007], sum[10007][10007];
ll fac[10007], inv_fac[10007], dp2[10007];
pair<int, int> pr[10007];

inline ll quick_pow(ll x, ll p, ll mod){
    ll ans = 1;
    while (p){
        if (p & 1) ans = ans * x % mod;
        x = x * x % mod;
        p >>= 1;
    }
    return ans;
}

inline void init(int n){
    fac[0] = 1;
    for (int i = 1; i <= n; i++){
        fac[i] = fac[i - 1] * i % mod;
    }
    inv_fac[n] = quick_pow(fac[n], mod - 2, mod);
    for (int i = n - 1; i >= 0; i--){
        inv_fac[i] = inv_fac[i + 1] * (i + 1) % mod;
    }
}

inline int get_sum(int l1, int r1, int l2, int r2){
    return sum[r1][r2] - sum[l1 - 1][r2] - sum[r1][l2 - 1] + sum[l1 - 1][l2 - 1];
}

inline ll comb(int n, int m){
    if (n < 0 || m < 0 || n < m) return 0;
    return fac[n] * inv_fac[m] % mod * inv_fac[n - m] % mod;
}

int main(){
    int n, ni;
    cin >> n;
    ni = n + 1;
    init(n);
    for (int i = 1; i <= n; i++){
        cin >> pr[i].first;
    }
    for (int i = 1; i <= n; i++){
        cin >> pr[i].second;
    }
    sort(pr + 1, pr + n + 1);
    pr[ni].second = ni;
    for (int i = 1; i <= ni; i++){
        for (int j = 1; j < i; j++){
            if (pr[i].second > pr[j].second) dp1[i] = max(dp1[i], dp1[j]);
        }
        dp1[i]++;
    }
    for (int i = 1; i <= n; i++){
        for (int j = 1; j <= n; j++){
            sum[i][j] = sum[i][j - 1] + sum[i - 1][j] - sum[i - 1][j - 1];
            if (pr[i].second == j) sum[i][j]++;
        }
    }
    dp2[ni] = 1;
    for (int i = n; i >= 0; i--){
        for (int j = i + 1; j <= ni; j++){
            if (pr[j].second > pr[i].second && dp1[j] == dp1[i] + 1){
                int t = i == 0 ? 0 : get_sum(1, i - 1, pr[i].second + 1, pr[j].second - 1);
                dp2[i] = (dp2[i] + dp2[j] * comb(t + get_sum(i + 1, j - 1, 1, pr[i].second), t) % mod) % mod;
            }
        }
    }
    cout << dp1[ni] + n - 1 << " " << dp2[0];
    return 0;
}