题解 P2791 【幼儿园篮球题】

· · 题解

题面传送门。

题意:开始给定 N, M, S, LS 次询问,给定 n, m, k,表示有 m1n-m0,在这当中随机选 k 个数,问选出数和的 L 次方的期望值。答案对 998244353 取模。

推式子大法好!

Solution:

考虑每个 x^L 的贡献,这等价于从中要选出 x1k-x0,于是其贡献为 x^L\dbinom{m}{x}\dbinom{n-m}{k-x}

那么这个题目实际上转化为求:

\frac{\sum\limits_{x=0}^kx^L\dbinom{m}{x}\dbinom{n-m}{k-x}}{\dbinom{n}{k}}

分母是关于 x 的常数,所以考虑一下分子就行。

发现 x^L 这个东西很不好看,考虑斯特林数。记 S(n, m) 表示第二类斯特林数:

x^L=\sum_{i=0}^x\binom{x}{i}i!S(L, i)

证明考虑组合意义就行。

代回原式:

&=\sum_{x=0}^k\sum_{i=0}^x\binom{m}{x}\binom{n-m}{k-x}\binom{x}{i}i!S(L, i) \\&=\sum_{i=0}^k\sum_{x=i}^k\binom{m}{x}\binom{n-m}{k-x}\binom{x}{i}i!S(L, i) \\&=\sum_{i=0}^kS(L, i)i!\sum_{x=i}^k\binom{m}{i}\binom{m-i}{x-i}\binom{n-m}{k-x} \\&=\sum_{i=0}^kS(L, i)i!\binom{m}{i}\sum_{x=0}^{k-i}\binom{m-i}{x}\binom{n-m}{k-i-x} \\&=\sum_{i=0}^kS(L, i)i!\binom{m}{i}\binom{n-i}{k-i} \end{aligned}

(如果上述推式子过程有不了解的知识点,可以戳这)

又有:

S(L, i)&=\frac{1}{i!}\sum_{j=0}^i(-1)^j\binom{i}{j}(i-j)^L \\&=\sum_{j=0}^i\frac{(-1)^j}{j!}\cdot\frac{(i-j)^L}{(i-j)!} \end{aligned}

是个卷积的形式,于是可以在 L\log L 的时间内预处理出 S(L, 0), S(L, 1), ...S(L, L)

over,时间复杂度为 O(L\log L +SL)

Code:

实现的时候注意一下常数问题,不过我好像没卡过常?(

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <map>
#include <set>
using namespace std;
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define dingyi int mid = l + r >> 1, ls = p << 1, rs = p << 1 | 1
#define y0 y_csyakioi_0
#define y1 y_csyakioi_1
#define rep(i, x, y) for(int i = x; i <= y; ++i)
#define per(i, x, y) for(int i = x; i >= y; --i)
#define repg(i, u) for(int i = head[u]; i; i = e[i].nxt)
inline int read(){
    int x = 0, f = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9'){ x = x * 10 + (ch ^ 48); ch = getchar(); }
    return x * f;
}
const int N = 800010, M = 20000010, mod = 998244353, g_ = 3, invg = 332748118;
int n, m, s, l, nn, mm, kk, S[N], B[N], e[N], r[N], rev, fac[M], inv[M];
inline int fpow(int x, int p){ int ans = 1; for(; p; p >>= 1, x = 1ll * x * x % mod) if(p & 1) ans = 1ll * ans * x % mod; return ans; }
inline void clr(int *A, int x){ memset(A, 0, sizeof(int)*x); }
inline void cpy(int *A, int *B, int x){ memcpy(A, B, sizeof(int)*x); }
inline void NTT(int typ, int *a, int lim){
    if(rev != lim){ rev = lim; rep(i, 0, lim - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0); }
    rep(i, 0, lim - 1) if(i < r[i]) swap(a[i], a[r[i]]);
    for(int mid = 1; mid < lim; mid <<= 1){
        int R = mid << 1, rt = fpow(typ == 1 ? g_ : invg, (mod - 1) / R);
        for(int j = 0; j < lim; j += R){
            int w = 1;
            for(int k = 0; k < mid; ++k, w = 1ll * w * rt % mod){
                int x = a[j | k], y = 1ll * w * a[j | k | mid] % mod;
                a[j | k] = (x + y) % mod; a[j | k | mid] = (x - y + mod) % mod;
            }
        }
    }
    if(typ < 0){ int k = fpow(lim, mod - 2); rep(i, 0, lim - 1) a[i] = 1ll * a[i] * k % mod; }
}
inline void ptm(int *A, int *B, int len){ rep(i, 0, len - 1) A[i] = 1ll * A[i] * B[i] % mod; }
inline void Polymul(int *A, int *B, int len, int m){
    int lim = 1; while(lim < len) lim <<= 1;
    cpy(e, B, lim); clr(e + m, lim - m);
    NTT(1, A, lim); NTT(1, e, lim); ptm(A, e, lim);
    NTT(-1, A, lim); clr(A + m, lim - m); 
}
inline void initfac(int n){
    fac[0] = 1;
    rep(i, 1, n) fac[i] = 1ll * fac[i - 1] * i % mod;
    inv[n] = fpow(fac[n], mod - 2);
    per(i, n - 1, 0) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}
inline int C(int n, int m){ return n < m ? 0 : 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }
inline int invC(int n, int m){ return 1ll * inv[n] * fac[m] % mod * fac[n - m] % mod; }
inline void mian(){
    n = read(); m = read(); s = read(); l = read(); initfac(max(n, l));
    rep(i, 0, l) S[i] = (i & 1) ? mod - inv[i] : inv[i];
    rep(i, 0, l) B[i] = 1ll * fpow(i, l) * inv[i] % mod;
    Polymul(S, B, l + 1 << 1, l + 1);
    while(s--){
        nn = read(); mm = read(); kk = read(); int ans = 0;
        rep(i, 0, min(min(mm, kk), l)) ans = (ans + 1ll * S[i] * fac[i] % mod * C(mm, i) % mod * C(nn - i, kk - i) % mod) % mod;
        printf("%d\n", 1ll * ans * invC(nn, kk) % mod); 
    }
}
int main(){ int qwq = 1; while(qwq--) mian(); return 0; }