「解题报告」[ARC148E] ≥ K

· · 题解

好题,全程没看题解做出来一道计数题感觉非常开心x(虽然爆了好几发 好几个 corner cases 没有判)

有一点想提的,就是这种题你可以先去观察答案,发现答案看起来就像是若干个数乘起来而不是加起来的,所以你就可以考虑答案应该是若干个组合数之类的乘积。这也算是一种技巧吧)

思路

题目要求相邻两个数之和大于等于 k,那么首先我们可以把所有的数分成两类,一类是 \ge \frac{k}{2} 的数,一类是 < \frac{k}{2} 的数、

这样分类后,我们可以发现 \ge \frac{k}{2} 的数可以随便放,因为无论如何放两个数之和都 \ge k,而所有 < \frac{k}{2} 的数就都不能相邻着放。

那么我们考虑 \ge \frac{k}{2} 的数与 < \frac{k}{2} 的数之间要怎么放。不难发现,如果放了一个 < \frac{k}{2} 的数 a,那么与之相邻的数必须 \ge k - a

我们可以考虑一种放数的顺序:从小到大考虑 < \frac{k}{2} 的数 a。首先先把所有 \ge k - a 的数全放进去,然后再往里面插 a

假如说我们现在要放一个更大的数 b,那么此时 \ge k - b< k - a 的数必定不可以和 a 放在一起。那么也就是说,一旦这个数 a 放下了,之后无论什么数都不可以与 a 相邻了。

那么我可以考虑随时记录一下当前还有哪些空位还可以放数。当放所有 \ge k - a 的数时,我们相当于把若干个数放到若干个空位中(因为 \ge \frac{k}{2} 可以挨着放),而我们每放一个数,都会多出一个新的空位(因为我们占了一个空位,而这个数左右两边又分别多出一个空位);当放所有 a 时,我们相当于在若干个空位中选出几个空位放数(因为 < \frac{k}{2} 的数不能挨着放),而我们每放一个数,都会减少一个空位。然后把每次选数的方案数乘起来就是答案。

式子其实不难,建议自己写。

具体式子

a 出现的次数为 cnt_a,当前空位为 s

假如我们现在要放一个 \ge \frac{k}{2} 的数 a,那么答案乘以 \displaystyle \binom{cnt_a+s-1}{s-1},并且 s \gets s + cnt_a

假如我们现在要放一个 < \frac{k}{2} 的数 a,那么答案乘以 \displaystyle \binom{s}{cnt_a},并且 s \gets s - cnt_a

代码

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400005, P = 998244353;
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
int n, k, a[MAXN];
int fac[MAXN], inv[MAXN];
int C(int n, int m) {
    if (n < m) return 0;
    return 1ll * fac[n] * inv[m] % P * inv[n - m] % P;
}
vector<pair<int, int>> w;
int main() {
    scanf("%d%d", &n, &k);
    fac[0] = 1;
    for (int i = 1; i <= n; i++) 
        fac[i] = 1ll * fac[i - 1] * i % P;
    inv[n] = qpow(fac[n], P - 2);
    for (int i = n; i >= 1; i--)
        inv[i - 1] = 1ll * inv[i] * i % P;
    assert(inv[0] == 1);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    sort(a + 1, a + 1 + n);
    pair<int, int> q = {-1, -1};
    a[0] = -1;
    for (int i = 1; i <= n; i++) { // 处理相同的数字
        if (a[i] != a[i - 1]) {
            if (q.first != -1) {
                w.push_back(q);
            }
            q.first = a[i], q.second = 1;
        } else {
            q.second++;
        }
    }
    w.push_back(q);
    int r = w.size() - 1;
    int mid = -1;
    while (mid + 1 < w.size() && w[mid + 1].first * 2 < k) mid++;
    int space = 1;
    int ans = 1;
    for (int l = 0; l <= mid; l++) { // 从小到大考虑 < k/2 的数
        while (r > mid && w[l].first + w[r].first >= k) { // 加入所有 >= k/2 的数
            ans = 1ll * ans * C(w[r].second + space - 1, space - 1) % P;
            if (!ans) {
                printf("0\n");
                return 0;
            }
            space += w[r].second;
            r--;
        }
        ans = 1ll * ans * C(space, w[l].second) % P;
        if (!ans) {
            printf("0\n");
            return 0;
        }
        space -= w[l].second;
    }
    while (r > mid) { // 把剩下的 >= k/2 的数加进去
        ans = 1ll * ans * C(w[r].second + space - 1, space - 1) % P;
        if (!ans) {
            printf("0\n");
            return 0;
        }
        space += w[r].second;
        r--;
    }
    printf("%d\n", ans);
    return 0;
}
// owo