题解 「DTOI-5」进行一个排的重 (Maximum Version)
下文设
- 求解
f(a)_{\max}
首先给出结论:答案为
为了证明这一点,首先我们需要证明如下结论:
- 任何一个满足
\exists 1 \leq i \leq n ,使得p'_i 和q'_i 均不能产生贡献的方案,一定可以调整为\forall 1 \leq i \leq n ,使得p'_i 或q'_i 至少一个产生贡献的方案,且改后不劣。
证明:这里我们考虑从前往后依次调整每一个不满足条件的项。对于不产生贡献的
这里考虑把
- 对于
p' 而言,原p'_i 会带来1 的新贡献。 - 对于
q' 而言,\displaystyle\max_{k = 1}^j q'_k < q'_i 时会带来1 的新贡献,否则\displaystyle\max_{k = 1}^j q'_k < q'_i, \displaystyle\max_{k = 1}^{j + 1} q'_k > q'_i ,于是有q'_{j + 1} > q'_i ,即不会变劣。
则单次调整带来的新贡献
于是,我们可以将任意方案调整为一个不劣且每对
现在我们只需要考虑两个均产生贡献的情况,则:
- 这些项的个数不超过
q_0 的 LIS 长度,否则一定存在更长的 LIS。 - 抓出任意一个
q_0 的 LIS 中的项都产生2 的贡献的a' ,我们可以用上面的方式将其调整至上限。
利用该结论,又因为本题中
- 求解
f(a)_{\max} 的方案数
考虑利用一定存在一个
设
设
初值:
转移:
- 解释一下这个转移:我们每次将满足第一个条件的元素
k 顺序地放入a' 中的原a_i, a_j 之间,再将满足第二个条件的元素插板到其间。
答案:
暴力实现是
综上,时间复杂度为
代码:
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int mod = 998244353;
int dp1[10007], sum[10007][10007];
ll fac[10007], inv_fac[10007], dp2[10007];
pair<int, int> pr[10007];
inline ll quick_pow(ll x, ll p, ll mod){
ll ans = 1;
while (p){
if (p & 1) ans = ans * x % mod;
x = x * x % mod;
p >>= 1;
}
return ans;
}
inline void init(int n){
fac[0] = 1;
for (int i = 1; i <= n; i++){
fac[i] = fac[i - 1] * i % mod;
}
inv_fac[n] = quick_pow(fac[n], mod - 2, mod);
for (int i = n - 1; i >= 0; i--){
inv_fac[i] = inv_fac[i + 1] * (i + 1) % mod;
}
}
inline int get_sum(int l1, int r1, int l2, int r2){
return sum[r1][r2] - sum[l1 - 1][r2] - sum[r1][l2 - 1] + sum[l1 - 1][l2 - 1];
}
inline ll comb(int n, int m){
if (n < 0 || m < 0 || n < m) return 0;
return fac[n] * inv_fac[m] % mod * inv_fac[n - m] % mod;
}
int main(){
int n, ni;
cin >> n;
ni = n + 1;
init(n);
for (int i = 1; i <= n; i++){
cin >> pr[i].first;
}
for (int i = 1; i <= n; i++){
cin >> pr[i].second;
}
sort(pr + 1, pr + n + 1);
pr[ni].second = ni;
for (int i = 1; i <= ni; i++){
for (int j = 1; j < i; j++){
if (pr[i].second > pr[j].second) dp1[i] = max(dp1[i], dp1[j]);
}
dp1[i]++;
}
for (int i = 1; i <= n; i++){
for (int j = 1; j <= n; j++){
sum[i][j] = sum[i][j - 1] + sum[i - 1][j] - sum[i - 1][j - 1];
if (pr[i].second == j) sum[i][j]++;
}
}
dp2[ni] = 1;
for (int i = n; i >= 0; i--){
for (int j = i + 1; j <= ni; j++){
if (pr[j].second > pr[i].second && dp1[j] == dp1[i] + 1){
int t = i == 0 ? 0 : get_sum(1, i - 1, pr[i].second + 1, pr[j].second - 1);
dp2[i] = (dp2[i] + dp2[j] * comb(t + get_sum(i + 1, j - 1, 1, pr[i].second), t) % mod) % mod;
}
}
}
cout << dp1[ni] + n - 1 << " " << dp2[0];
return 0;
}