CF1666F Fancy Stack

· · 题解

upd:补充说明了方程。

来一个和题解区两篇题解不一样的 dp 方法。

首先,两种方案不同当且仅当存在不相同,显然这个是不利于我们 dp 的。一般我们希望两种方案不同当且仅当存在数在原序列的位置不同。因此首先我们首先将题目进行一次转换。假如按后面不同的定义算出来的结果为 ans,实际上的答案为 \dfrac{ans}{\prod (cnt_i!)},其中 cnt_i 表示 i 出现的次数。

我们按顺序填这个序列,一次填一奇一偶。我们设 dp_{i,j} 表示第 i 个偶数位置(即位置 2i)放 a_j 的方案数。那么转移是简单的:(注意 dp_{1,j} 要特殊处理)

dp_{i,j}=\sum_{a_k < a_j}dp_{i-1,k} \times \max(0, scnt_{a_k-1} - (2i - 3))

其中 scntcnt 的前缀和数组,scnt_i 即小于等于 i 的数量。注意这里要取 \max 是因为可能后面那是负数(即没有合法的数可以选择)。

可以参考下面这张图:

直接暴力转移是 O(n^3) 的,在 n \le 5000 的范围下不可以接受,我们考虑优化。

由于序列有序,所以可以发现有用的(也就是后面是正数)k 的取值是一段区间,并且这个区间也是有单调性的(随着 a_j 增大,k 能选的就越多)。

因此每次转移之前处理出 dp_{i-1} 乘后面那坨的前缀和即可。

时间复杂度 O(n ^2)

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
constexpr int N = 5005, mod = 998244353;
int a[N], dp[N][N], cnt[N], f[N], invf[N], s[N];
inline int qpow(int b, int k)
{
    int res = 1;
    while (k) {
        if (k & 1) res = (ll) res * b % mod;
        b = (ll) b * b % mod;
        k >>= 1;
    }
    return res;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    f[0] = 1;
    for (int i = 1; i <= 5000; ++i) f[i] = (ll) f[i - 1] * i % mod;
    invf[5000] = qpow(f[5000], mod - 2);
    for (int i = 4999; i >= 0; --i) invf[i] = (ll) invf[i + 1] * (i + 1) % mod;
    int T;
    cin >> T;
    while (T--) {
        int n, ans = 0;
        cin >> n;
        for (int i = 1; i <= n; ++i) cnt[i] = 0;
        for (int i = 1; i <= n; ++i) cin >> a[i], ++cnt[a[i]];
        for (int i = 1; i <= n; ++i) cnt[i] += cnt[i - 1];
        for (int i = 1; i <= n; ++i) {
            dp[1][i] = cnt[a[i] - 1];
        }
        for (int i = 2; i <= n / 2; ++i) {
            for (int j = 1; j <= n; ++j) {
                s[j] = (s[j - 1] + (ll) dp[i - 1][j] * (cnt[a[j] - 1] - (2 * i - 3))) % mod;
            }
            int l = 1, r = 0;
            for (int j = 1; j <= n; ++j) {
                while (a[r + 1] < a[j]) ++r;
                while (l <= r && cnt[a[l] - 1] < 2 * i - 3) ++l;
                dp[i][j] = (s[r] - s[l - 1]) % mod;
            }
        }
        for (int i = 1; i <= n; ++i) ans = (ans + dp[n / 2][i]) % mod;
        for (int i = 1; i <= n; ++i) ans = (ll) ans * invf[cnt[i] - cnt[i - 1]] % mod;
        cout << (ans + mod) % mod << '\n';
    }
    return 0;
}