Bubble Sort 题解

· · 题解

首先,对于冒泡排序的轮数,有一个结论。

f_i 为在 [1,i) 中,比 i 位置上的数大的数的个数,那么冒泡排序的轮数就为 \max f_i,且一种 f 唯一对应一种排列。

所以可以考虑数所有 f 的数量,就可以得到所有排列数量。

题目中的 b_i 就是 f 的前缀最大值,所以 b_i 单调递增。

考虑如何刻画这个限制,如果有 yb\le k,那么 b_y 一定是 \le k 的,因为它单调递增。

所以要满足 \le kb 的个数 \ge l,其实就是 b[1,l] 这个前缀都要 \le k,也就是 f[1,l] 这个前缀都要 \le k,可以把这个限制记为 (l,k)

然后考虑 \le r 的限制,可以考虑容斥,将 \le r 的限制变成 \ge r + 1,就可以记为 (r+1,k),容斥系数为 -1

具体如何操作呢?

考虑把限制按照 k 从小到大排序,此时可以发现,如果一个前缀 [1,x] 已经被限制 \le k 了,那么下一个限制 (x',k') 只需要满足 (x,x']\le k',因为 [1,x] 一定是 \le k' 的。

所以可以考虑 dp。

dp_{i,j} 表示当前是第 i 个限制,[1,j] 的前缀都被限制了的方案数。

记当前限制为 (x,k)

x\le j,那么有 dp_{i,j}\leftarrow dp_{i-1,j}

否则就要限制 (j,x]\le k,所以有 dp_{i,x}\leftarrow dp_{i-1,j}\times\prod_{p=j+1}^x\min(p,k+1)(这里和 p 取最小值是因为要满足 0\le f_p\lt p)。

这里单次转移可以预处理阶乘、阶乘逆元和使用快速幂做到 O(\log n)

dp_{i,j} 的第二维离散化,复杂度 O(m^2\log n)

#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 1e6 + 10, M = 1e3 + 10, inf = 1e9, mod = 998244353;

int add( int x, int v) {
    x += v;
    return x >= mod ? x - mod : x;
}

int n, m;

struct lim {
    int k, l, r;
} q[M];

int ind[2 * M], tot;
int fac[N], inv[N];
int f[2 * M], g[2 * M];

int ID( int x) {
    return lower_bound(ind + 1, ind + tot + 1, x) - ind;
}

int ksm( int a, int b, int res = 1) {
    while (b) {
        if (b & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod, b >>= 1;
    }

    return res;
}

int cal( int l, int r, int k) {
    if (k >= r) {
        return 1ll * fac[r] * inv[l] % mod;
    } else if (k <= l) {
        return ksm(k, r - l);
    } else {
        return 1ll * ksm(k, r - k) * fac[k] % mod * inv[l] % mod;
    }
}

void solve() {
    cin >> n >> m;

    tot = 0, ind[++ tot] = 0;

    for ( int i = 1; i <= m; i ++) {
        cin >> q[i].k >> q[i].l >> q[i].r, q[i].r += 1;
        ind[++ tot] = q[i].l;
        if (q[i].r <= n) ind[++ tot] = q[i].r;
    }

    sort(q + 1, q + m + 1, [&]( lim a, lim b) {
        return a.k < b.k;
    });
    sort(ind + 1, ind + tot + 1), tot = unique(ind + 1, ind + tot + 1) - ind - 1;

    f[1] = 1;

    for ( int o = 1; o <= m; o ++) {
        int k = q[o].k, l = q[o].l, r = q[o].r;
        for ( int i = 1; i <= tot; i ++) g[i] = 0;

        int L = ID(l), R = ID(r);

        for ( int i = 1; i <= tot; i ++) {
            if (ind[i] < l)
                g[L] = add(g[L], 1ll * f[i] * cal(ind[i], l, k + 1) % mod);
            else g[i] = add(g[i], f[i]);

            if (r > n) continue ;

            if (ind[i] < r)
                g[R] = add(g[R], add(-1ll * f[i] * cal(ind[i], r, k + 1) % mod, mod));
            else g[i] = add(g[i], mod - f[i]);
        }   

        for ( int i = 1; i <= tot; i ++) f[i] = g[i];
    }

    int ans = 0;
    for ( int i = 1; i <= tot; i ++)
        ans = add(ans, 1ll * f[i] * cal(ind[i], n, n) % mod), f[i] = 0;

    cout << ans << '\n';
}

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    inv[0] = fac[0] = 1;
    for ( int i = 1; i < N; i ++) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = ksm(fac[i], mod - 2);

    int T; cin >> T;
    while (T --) solve();

    return 0;
}