题解 AT2070 【[ARC061D] 3人でカードゲーム / Card Game for Three】
command_block
2021-01-12 07:30:23
**题意** : 有三堆牌,分别有 $n_1,n_2,n_3$ 张。牌上写着数字
$1,2,3$ 中的一个。
先从牌堆 $1$ 中抽一张,接下来,牌上写着几就从几号牌堆抽取。
求在所有可能的 $3^{n_1+n_2+n_3}$ 种方案中,先把牌堆 $1$ 抽空的方案数。
答案对 $10^9+7$ 取模。
$n_1,n_2,n_3\leq 3\times 10^5$ ,时限 $\texttt{3s}$。
------------
条件计数的路子无非两条,要么寻找更简洁的充要条件,要么容斥。
题目中的过程比较精巧,考虑寻找充要条件。
把抽取出的牌排成一个序列,显然,每种放置方式都恰好对应一个序列。(构造映射)
但是,由于可能拿不完牌,所以一个序列可能对应很多种方案,具体地,一个长为 $m$ 的序列对应 $3^{n_1+n_2+n_3-m}$ 种方案。(检查反映射)
我们思考对这个序列的约束,可以发现,除了率先将堆 $1$ 拿空之外,没有任何约束。(检查充要条件)
于是,问题就变成了 : 对每个长度,求先将堆 $1$ 拿空的序列个数。
显然,操作序列中一定恰有 $n_1$ 个 $1$ ,且最后一个必须是 $1$。
枚举抽出的非 $1$ 牌个数 $k$ ,方案数为 $\dbinom{k+n_1-1}{k}\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i}$
解释一下,$\binom{k+n_1-1}{k}$ 表示 $n_1-1$ 个自由的 $1$ 与非 $1$ 混合的方案数,后面的求和是瓜分非 $1$ 牌的方案数。
但糟糕的是,后半部分是一个组合数部分和,这似乎没有什么快速的方法分别求解,考虑递推。
$S(k)=\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i}$
$=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k}{i}$
将组合数裂开 ;
$=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\dbinom{k-1}{i-1}$
$=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\sum\limits_{k-n_3-1\leq i\leq n_2-1}\dbinom{k-1}{i}$
$=2S(k-1)-\dbinom{k-1}{k-n_3-1}-\dbinom{k-1}{n_2}$
注意组合数可能不合法,此时值为 $0$。
求出各个 $S(k)$ 之后,答案即为下式 :
$$\sum\limits_{k=0}^{n_2+n_3}3^{n_1+n_2+n_3-k}\dbinom{n_1-1+k}{k}S(k)$$
不看题解玩出来还是有点小激动的……
```cpp
#include<algorithm>
#include<cstdio>
#define ll long long
#define MaxN 900500
using namespace std;
const int mod=1000000007;
ll powM(ll a,int t=mod-2){
ll ret=1;
while(t){
if (t&1)ret=ret*a%mod;
a=a*a%mod;t>>=1;
}return ret;
}
ll fac[MaxN],ifac[MaxN];
ll C(int n,int m){
if (m<0||n<m)return 0;
return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void Init(int n)
{
fac[0]=1;
for (int i=1;i<=n;i++)
fac[i]=fac[i-1]*i%mod;
ifac[n]=powM(fac[n]);
for (int i=n;i;i--)
ifac[i-1]=ifac[i]*i%mod;
}
void preS(int n2,int n3,ll *S)
{
S[0]=1;
for (int k=1;k<=n2+n3;k++)
S[k]=(2*S[k-1]-C(k-1,k-1-n3)-C(k-1,n2))%mod;
}
int n1,n2,n3,N;
ll S[MaxN],pw3[MaxN];
int main()
{
scanf("%d%d%d",&n1,&n2,&n3);
Init(N=n1+n2+n3);
preS(n2,n3,S);
ll ans=0,buf=powM(3,N-n1),sav=powM(3);
for (int k=0;k<=n2+n3;k++){
ans=(ans+buf*C(n1+k-1,k)%mod*S[k])%mod;
buf=buf*sav%mod;
}printf("%lld",(ans+mod)%mod);
return 0;
}
```