题解 P2791 【幼儿园篮球题】

· · 题解

注意:题解所用变量名称和题目中是不同

题意:给定系数 k,多组询问超几何分布 x~H(n,M,N)k 次方期望 E(x^k)

难度定位:简单紫题,文化课选手应该能秒。

前两档部分分分别是套课本期望方差公式和直接暴力枚举投进几个球,用超几何分布公式计算。

超几何分布中, P(x=i)=\frac{\tbinom{M}{i}\tbinom{N-M}{n-i}}{\tbinom{N}{n}}

所以

ans=\sum_{i=0}^ni^k*\frac{\tbinom{M}{i}\tbinom{N-M}{n-i}}{\tbinom{N}{n}}=\sum_{i=1}^ni^k*\frac{\tbinom{M}{i}\tbinom{N-M}{n-i}}{\tbinom{N}{n}}

考虑 i^k 的组合意义是将 k不同的球放入 i不同的盒子中的方案数,设 S_2(k,j) (第二类斯特林数)表示把 k不同的球放入 j相同的盒子且要求盒子非空的方案数,考虑枚举非空盒子数目 j 那么显然有

i^k=\sum_{j=0}^iS_2(k,j)j!\tbinom{i}{j}

代入题中式子可以得到

\sum_{i=1}^ni^k\frac{\tbinom{M}{i}\tbinom{N-M}{n-i}}{\tbinom{N}{n}}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\frac{\sum_{i=1}^n\tbinom{M}{i}\tbinom{N-M}{n-i}\sum_{j=0}^kS_2(k,j)j!\tbinom{i}{j}}{{\tbinom{N}{n}}}

交换求和次序

=\frac{\sum_{j=0}^kS_2(k,j)j!\sum_{i=1}^n\tbinom{M}{i}\tbinom{i}{j}\tbinom{N-M}{n-i}}{{\tbinom{N}{n}}}

考虑 \tbinom{M}{i}\tbinom{i}{j} 的组合意义是从 M 个人中选出 i 个参加比赛的,再从 i 个人中选出 j 个获奖。

那么我们可以先硬点 j 个人获奖,然后选 i-j 个人0

可以得到 \tbinom{M}{i}\tbinom{i}{j}=\tbinom{M}{j}\tbinom{M-j}{i-j}

因此原式化为下式(与 i 无关扔前面)

\ \ \ \ \ \ \ =\frac{\sum_{j=0}^kS_2(k,j)j!\tbinom{M}{j}\sum_{i=1}^n\tbinom{M-j}{i-j}\tbinom{N-M}{n-i}}{{\tbinom{N}{n}}}

后面这玩意是一个范德蒙德卷积,进一步化简

=\frac{\sum_{j=0}^kS_2(k,j)j!\tbinom{M}{j}\tbinom{N-j}{n-j}}{{\tbinom{N}{n}}}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \

这样只要知道 S_2(k,j) 就可以 O(k) 计算答案了,而O(qk)的复杂度是可以通过这题的。

考虑第二类斯特林数 S_2(k,j) 的组合意义,利用容斥原理可以得到

S_2(k,j)=\frac{1}{j!}\sum_{i=0}^j(-1)^i(j-i)^k\tbinom{j}{i}=\sum_{i=0}^j\frac{(-1)^i(j-i)^k}{i!(j-i)!}

设生成函数 f(x)=\sum_i(-1)^ii!*x^i,g(x)=\sum_i\frac{i^k}{i!}*x^i,显然对于一个给定的 kS_2(k,j)f*gx^j 的系数,可以 O(k+N) 计算 f(x),g(x) (可以用线性筛处理 i^k ),然后 \text{NTT} 就可以 O(klog_2k+N) 计算所需要的第二类斯特林数,从而计算出结果。总复杂度 O(N+k(q+log_2k))

代码

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=5.5e5,M=2e7+2,p=998244353;
int f[N],g[N],r[N],yg[N],ig[N],ifac[M],fac[M],ss[N],mc[N];
int c;
inline int ksm(register int x,register int y)
{
    register int r=1;
    while (y)
    {
        if (y&1) r=(ll)r*x%p;
        x=(ll)x*x%p;
        y>>=1;
    }
    return r;
}
void dft(register int *a,register int xs,register int limit)
{
    register int i,j,k,l,w,wn,b,c;
    for (i=1;i<limit;i++) if (i<r[i]) swap(a[i],a[r[i]]);
    for (i=1;i<limit;i=l)
    {
        l=i<<1;
        if (xs) wn=ig[l]; else wn=yg[l];
        for (j=0;j<limit;j+=l)
        {
            w=1;
            for (k=0;k<i;k++,w=(ll)w*wn%p)
            {
                b=a[j|k];c=(ll)w*a[j|k|i]%p;
                a[j|k]=(b+c)%p;
                a[j|k|i]=(b-c+p)%p;
            }
        }
    }
    if (xs)
    {
        limit=ksm(limit,p-2);
        for (i=0;i<xs;i++) a[i]=(ll)a[i]*limit%p;
    }
}
inline void read(int &x)
{
    c=getchar();
    while ((c<48)||(c>57)) c=getchar();
    x=c^48;c=getchar();
    while ((c>=48)&&(c<=57))
    {
        x=x*10+(c^48);
        c=getchar();
    }
}
int main()
{
    register int n,m,q,k,i,j,x,limit=1,l=0,gs=0;
    read(n);read(c);read(q);read(k);
    while (limit<=k) limit<<=1,++l;
    n=max(n,limit-1);
    limit<<=1;
    for (i=1;i<limit;i++) r[i]=r[i>>1]>>1|(i&1)<<l;
    ig[limit]=ksm(yg[limit]=ksm(3,(p-1)/limit),p-2);
    for (i=limit>>1;i;i>>=1)
    {
        yg[i]=(ll)yg[i<<1]*yg[i<<1]%p;
        ig[i]=(ll)ig[i<<1]*ig[i<<1]%p;
    }
    l=limit>>1;ifac[0]=ifac[1]=fac[0]=fac[1]=1;
    for (i=2;i<=n;i++) ifac[i]=p-(ll)p/i*ifac[p%i]%p,fac[i]=(ll)fac[i-1]*i%p;
    for (i=2;i<=n;i++) ifac[i]=(ll)ifac[i-1]*ifac[i]%p;
    mc[1]=1;
    for (i=2;i<l;i++)
    {
        if (!mc[i]) mc[ss[++gs]=i]=ksm(i,k);
        for (j=1;(j<=gs)&&(i*ss[j]<l);j++)
        {
            mc[i*ss[j]]=(ll)mc[i]*mc[ss[j]]%p;
            if (i%ss[j]==0) break;
        }
    }
    for (i=0;i<l;i++) if (i&1) f[i]=p-ifac[i]; else f[i]=ifac[i];
    for (i=1;i<l;i++) g[i]=(ll)ifac[i]*mc[i]%p;
    dft(f,0,limit);dft(g,0,limit);
    for (i=0;i<limit;i++) f[i]=(ll)f[i]*g[i]%p;
    dft(f,l,limit);
    while (q--)
    {
        read(n);read(m);read(x);
        j=0;
        for (i=min(min(x,m),k);i;i--) j=(j+(ll)f[i]*ifac[m-i]%p*fac[n-i]%p*ifac[x-i])%p;
        printf("%d\n",int((ll)j*ifac[n]%p*fac[x]%p*fac[m]%p));
    }
}

update:由于时限缩短,此代码已经无法通过此题。请使用高效的阶乘/阶乘逆元预处理手段。