P16416 题解

· · 题解

考虑在任意时刻 t 数组的形式。容易发现它仅有可能由 \leqslant 2 个连续段构成。具体证明即,下一个不能取的 c_i 只能至多取到 2 个连续段中一个的后面一个。

考虑对这个东西 DP。记 dp_i 表示,填完了 [1, i],后面转移是,先填 [i + 2, j](这段的 ci + 1),再填 i + 1(这个的 c 不能是 i + 1)的贡献。

要求钦定一个数的 c 不是 i + 1 是不好做的,于是考虑容斥。记 dp_{i, j} 表示,填完了 [1, i],有 j 段是合法的,剩下的容斥掉了。转移的时候,如果是不合法的段则另外贡献上 -1 的容斥系数。最后有 j 个合法的段则贡献上 j!。预处理一个区间 [l, r]a, b 贡献即容易做到 O(n ^ 2)

由于最后 t 的形式可能是 2 个连续段,所以说枚举前面 DP 的 [1, i],后面最后一段暴力算即可。时间复杂度 O(n ^ 2)

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define MOD 998244353
int n, a[5005], b[5005], c[5005], cnt[5005], fac[5005], inv[5005];
int qz[5005][5005], dp[5005][5005];
int qpow(int x, int y) {
    int res = 1;
    while (y) {
        if (y & 1) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}
int C(int x, int y) {
    if (x < y || y < 0) {
        return 0;
    }
    return fac[x] * inv[y] % MOD * inv[x - y] % MOD;
}
int A(int x, int y) {
    if (x < y || y < 0) {
        return 0;
    }
    return fac[x] * inv[x - y] % MOD;
}
void init() {
    fac[0] = inv[0] = 1;
    for (int i = 1; i <= 5e3; i++) {
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = qpow(fac[i], MOD - 2);
    }
}
void solve() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
        c[i] = min(c[i], n + 1);
    }
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i <= n; i++) {
        cin >> b[i];
    }
    for (int l = 1; l <= n; l++) {
        cnt[c[l]]++;
        qz[l][l] = (a[l] * l + b[l]) % MOD;
        int v = 1;
        for (int r = l + 1; r <= n; r++) {
            v = (a[r - 1] * r + b[r - 1]) % MOD * v % MOD;
            qz[l][r] = (a[r] * l + b[r]) % MOD * v % MOD;
        }
    }
    dp[0][0] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j <= i; j++) {
            for (int k = i + 1; k <= i + cnt[i + 1] + 1; k++) {
                if (k > n) {
                    break;
                }
                dp[k][j + 1] = (dp[k][j + 1] + dp[i][j] * A(cnt[i + 1], k - i - 1) % MOD * qz[i + 1][k]) % MOD;
            }
            for (int k = i + 1; k <= i + cnt[i + 1]; k++) {
                if (k > n) {
                    break;
                }
                dp[k][j] = (dp[k][j] - dp[i][j] * A(cnt[i + 1], k - i) % MOD * qz[i + 1][k] % MOD + MOD) % MOD;
            }
        }
    }
    int ans = 0;
    for (int i = 0; i <= n; i++) {
        //end i
        if (cnt[i + 1] >= n - i) {
            int v = A(cnt[i + 1], n - i);
            for (int j = i + 1; j <= n; j++) {
                v = (a[j] * (j + 1) + b[j]) % MOD * v % MOD;
            }
            for (int j = 0; j <= i; j++) {
                ans = (ans + dp[i][j] * fac[j] % MOD * v) % MOD;
            }
        }
    }
    cout << ans << "\n";
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    init();
    int tt = 1;
    while (tt--) {
        solve();
    }
    return 0;
}