题解:AT_arc148_e [ARC148E] ≥ K

· · 题解

题目让我们计数序列,但是序列有两个端点,每个元素似乎不是等价的。但是我们灵感一下,发现如果在序列里加上一个 $10^9 + 10$,再按照原来的限制对环而非序列计数,那么得到的结果和原问题结果**是一样的**! 于是我们先做一个简单限制:$< \frac{k}{2}$ 的数不能相邻。此时用组合数计算圆排列。设 $< \frac{k}{2}$ 的数有 $A$ 个,$\geq \frac{k}{2}$ 的数有 $B$ 个($A + B = n + 1$),那么简单限制下的总方案数为 ${B \choose A} \times {(B - 1)! \times A!}$。 然后尝试用期望的思想,通过乘一些概率让这个答案变得正确。我们考虑到刚才的方案有些不合法,是因为 $< \frac{k}{2}$ 的某个小数字夹在两个也比较小的 $\geq \frac{k}{2}$ 的数字之间,导致加起来 $< k$。那不合法的概率对于每个 $< \frac{k}{2}$ 的数字是不是独立的呢? 答案是肯定的!这是因为我们考虑从小到大往一个全部 $\geq \frac{k}{2}$ 的环里插入 $< \frac{k}{2}$ 的数字:相当于做了若干次如下问题: > 往一个大小为 $x + y$ 的环里插入一个元素。已知环里有 $x$ 个黑色元素和 $y$ 个白色元素,求被插入的元素两边都是黑色元素的概率。 每做一次这个问题,$x$ 就会减一(被插入过的位置后面不能用了)。而因为我们从小到大插入,已经被插入过的位置两边一定是很大的数字,能接纳接下来插入的所有数字,因此问题之间概率独立。那就做完了。 ```cpp /* K_crane x N_cat */ #include <bits/stdc++.h> #define lowbit(x) ((x) & (-(x))) using namespace std; const int N = 200010, mod = 998244353; inline long long qpow(long long a, long long b) { long long res = 1; while(b) { if(b & 1) res = res * a % mod; b >>= 1, a = a * a % mod; } return res; } long long fac[N], inv[N]; inline long long C(int n, int m) { if(n < m) return 0; return fac[n] * inv[m] % mod * inv[n - m] % mod; } inline void init_math() { fac[0] = inv[0] = 1; for(int i = 1; i <= 200001; ++i) fac[i] = fac[i - 1] * i % mod; inv[200001] = qpow(fac[200001], mod - 2); for(int i = 200000; i >= 1; --i) inv[i] = inv[i + 1] * (i + 1) % mod; } int n, a[N], k; long long res; int main() { // freopen("text.in", "r", stdin); // freopen("prog.out", "w", stdout); ios::sync_with_stdio(false); cin.tie(0), cout.tie(0); init_math(); cin >> n >> k; for(int i = 1; i <= n; ++i) cin >> a[i]; a[++n] = 1e9 + 1; sort(a + 1, a + n + 1); int dv = 0; for(int i = 1; i <= n; ++i) if(a[i] * 2 < k) dv = i; res = C(n - dv, dv) * fac[n - dv - 1] % mod * fac[dv] % mod; for(int i = 1; i <= dv; ++i) { int l = 1, r = n, ps = 0; while(l <= r) { int mid = (l + r) >> 1; if(a[mid] + a[i] >= k) ps = mid, r = mid - 1; else l = mid + 1; } long long cnt = n - ps + 1 - (i - 1); long long all = n - dv - (i - 1); if(cnt == 1 && all == 1); else res = res * (cnt * (cnt - 1) / 2 % mod) % mod * qpow(all * (all - 1) / 2 % mod, mod - 2) % mod; } for(int l = 1; l <= n; ++l) { int r = l; while(r < n && a[r + 1] == a[r]) ++r; res = res * qpow(fac[r - l + 1], mod - 2) % mod; l = r; } cout << res << '\n'; return 0; } /* */ ```