题解 CF451E 【Devu and Flowers】

· · 题解

首先,一句话题意:

s个球,n个盒子,第i个盒子最多能放f[i]个球,也可以不放,问放球的方法总数模10^9+7的模数。

如果盒子没有容纳限制,那么这个题就是一道经典的组合数学题,答案直接C^{n-1}_ {s+n-1},但这道题有f的限制,比较讨厌。

下面介绍两种方法。

1.母函数

这题的母函数很容易表示出来,为

f(x)=(1+x^1+x^2\dots+x^{f[1]})(1+x^1+x^2\dots+x^{f[2]})\dots(1+x^1+x^2\dots+x^{f[n]})

答案则为x^s项的系数,可惜此题中f[i]很大,因此不能使用母函数。

我着重介绍以下的方法

2.状压+组合数+容斥原理+Lucas定理

注意到除了n以外的数都特别大,而n特别小,只有n<=20。这让我们联想到了状态压缩,2^{20}的状态大小可以承受。

因此,我们枚举状态k。第i位表示第i个盒子是否放了超过其容量的球。考虑逆向求解,则答案为方案总数-不合法方案总数

方案总数比较简单,为C^{n-1}_ {s+n-1}。如果第i个盒子放的球数超过了其容量,那么显然,它至少放了f[i]+1个球。

我们将所有多放的球从总数中剔除出去,则剩下m=s-\sum_{i\in k}{(f[i]+1)}个球,且此时已经满足所有属于ki号盒子都超过了其容量。如果有s<m,则跳过这种情况,否则共有C^{n-1}_ {s-m+n-1}种情况。

for (int i=0;i<(1<<n);i++) {
    LL cnt=s;
    for (int j=1;j<=n;j++)
        if (i&(1<<(j-1))) 
            cnt-=f[j]+1;
    if (cnt<0) continue;
    res=(res+C(cnt+n-1,n-1))%mod;
}

但注意,这并不是最终的答案,因为假如我们计算出了一个放满了l个盒子的解,其实这个解的含义是“至少放满了l个盒子”,也就是说,它包含放满l,l+1,l+2\dots n个盒子的答案,明显有重复。

所以我们使用容斥原理,若状态中有奇数个盒子超出了范围,那么就在res中减,否则就加。

for (int i=0;i<(1<<n);i++) {
    LL cnt=s;int del=1;
    for (int j=1;j<=n;j++)
        if (i&(1<<(j-1))) 
            cnt-=f[j]+1,del=-del;
    if (cnt<0) continue;
    res=(res+C(cnt+n-1,n-1)*(LL)del)%mod;
}

最后说一下Lucas定理:解决大组合数取模。注意此题中组合数的第一项是一个可能为10^{14}的数,不能使用常规方法。

LL Lucas(LL a,LL b) {
    if (a<=mod&&b<=mod) return C(a,b);
    return C(a%mod,b%mod)*Lucas(a/mod,b/mod)%mod;//Lucas
}

附上AC代码:

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define LL long long
const LL mod=1e9+7;
LL Pow(LL x,LL y) {
    LL res=1;
    while (y) {
        if (y&1) res=res*x%mod;
        x=x*x%mod,y>>=1;
    }
    return res;
}
LL C(int n,int r) {
    if (n<r) return 0;
    r=min(r,n-r);
    LL res=1,a=1,b=1;
    for (int i=1;i<=r;i++)
        a=a*(n-i+1)%mod,b=b*i%mod;
    res=res*a%mod*Pow(b,mod-2)%mod;
    return res;
}
LL Lucas(LL a,LL b) {
    if (a<=mod&&b<=mod) return C(a,b);
    return C(a%mod,b%mod)*Lucas(a/mod,b/mod)%mod;
}
LL f[21];
LL solve(int n,LL s) {
    LL res=0;
    for (int i=0;i<(1<<n);i++) {
        LL cnt=s;int del=1;
        for (int j=1;j<=n;j++)
            if (i&(1<<(j-1))) 
                cnt-=f[j]+1,del=-del;
        if (cnt<0) continue;
        res=(res+Lucas(cnt+n-1,n-1)*(LL)del)%mod;
    }
    return (res+mod)%mod;
}
int main() {
    int n;LL s;
    while (~scanf("%d%lld",&n,&s)) {
        for (int i=1;i<=n;i++) scanf("%lld",&f[i]);
        printf("%lld\n",solve(n,s));
    }
}