题解 AT2070 【[ARC061D] 3人でカードゲーム / Card Game for Three】

command_block

2021-01-12 07:30:23

Solution

**题意** : 有三堆牌,分别有 $n_1,n_2,n_3$ 张。牌上写着数字 $1,2,3$ 中的一个。 先从牌堆 $1$ 中抽一张,接下来,牌上写着几就从几号牌堆抽取。 求在所有可能的 $3^{n_1+n_2+n_3}$ 种方案中,先把牌堆 $1$ 抽空的方案数。 答案对 $10^9+7$ 取模。 $n_1,n_2,n_3\leq 3\times 10^5$ ,时限 $\texttt{3s}$。 ------------ 条件计数的路子无非两条,要么寻找更简洁的充要条件,要么容斥。 题目中的过程比较精巧,考虑寻找充要条件。 把抽取出的牌排成一个序列,显然,每种放置方式都恰好对应一个序列。(构造映射) 但是,由于可能拿不完牌,所以一个序列可能对应很多种方案,具体地,一个长为 $m$ 的序列对应 $3^{n_1+n_2+n_3-m}$ 种方案。(检查反映射) 我们思考对这个序列的约束,可以发现,除了率先将堆 $1$ 拿空之外,没有任何约束。(检查充要条件) 于是,问题就变成了 : 对每个长度,求先将堆 $1$ 拿空的序列个数。 显然,操作序列中一定恰有 $n_1$ 个 $1$ ,且最后一个必须是 $1$。 枚举抽出的非 $1$ 牌个数 $k$ ,方案数为 $\dbinom{k+n_1-1}{k}\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i}$ 解释一下,$\binom{k+n_1-1}{k}$ 表示 $n_1-1$ 个自由的 $1$ 与非 $1$ 混合的方案数,后面的求和是瓜分非 $1$ 牌的方案数。 但糟糕的是,后半部分是一个组合数部分和,这似乎没有什么快速的方法分别求解,考虑递推。 $S(k)=\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i}$ $=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k}{i}$ 将组合数裂开 ; $=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\dbinom{k-1}{i-1}$ $=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\sum\limits_{k-n_3-1\leq i\leq n_2-1}\dbinom{k-1}{i}$ $=2S(k-1)-\dbinom{k-1}{k-n_3-1}-\dbinom{k-1}{n_2}$ 注意组合数可能不合法,此时值为 $0$。 求出各个 $S(k)$ 之后,答案即为下式 : $$\sum\limits_{k=0}^{n_2+n_3}3^{n_1+n_2+n_3-k}\dbinom{n_1-1+k}{k}S(k)$$ 不看题解玩出来还是有点小激动的…… ```cpp #include<algorithm> #include<cstdio> #define ll long long #define MaxN 900500 using namespace std; const int mod=1000000007; ll powM(ll a,int t=mod-2){ ll ret=1; while(t){ if (t&1)ret=ret*a%mod; a=a*a%mod;t>>=1; }return ret; } ll fac[MaxN],ifac[MaxN]; ll C(int n,int m){ if (m<0||n<m)return 0; return fac[n]*ifac[m]%mod*ifac[n-m]%mod; } void Init(int n) { fac[0]=1; for (int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod; ifac[n]=powM(fac[n]); for (int i=n;i;i--) ifac[i-1]=ifac[i]*i%mod; } void preS(int n2,int n3,ll *S) { S[0]=1; for (int k=1;k<=n2+n3;k++) S[k]=(2*S[k-1]-C(k-1,k-1-n3)-C(k-1,n2))%mod; } int n1,n2,n3,N; ll S[MaxN],pw3[MaxN]; int main() { scanf("%d%d%d",&n1,&n2,&n3); Init(N=n1+n2+n3); preS(n2,n3,S); ll ans=0,buf=powM(3,N-n1),sav=powM(3); for (int k=0;k<=n2+n3;k++){ ans=(ans+buf*C(n1+k-1,k)%mod*S[k])%mod; buf=buf*sav%mod; }printf("%lld",(ans+mod)%mod); return 0; } ```