题解 AT2370 【[AGC013D] Piling Up】

ez_lcw

2020-10-17 16:16:27

Solution

考场上用了一种奇怪的做法,不知道为什么就对了,考完后仔细想才想明白。 很巧妙的一种 dp 方式。 首先发现每次操作是拿一个球、放两个球、再拿一个球,总球数不变,所以有 $\text{黑球数}=n-\text{白球数}$。 那么可以设 $dp(i,j)$ 表示 $i$ 次操作后,当前箱子里白球还剩 $j$ 个所得到的的颜色序列的情况数。 显然有状态转移: 1. $dp(0,i)=1$,$0\leq i\leq n$ 2. 当 $i>0$ 时,有: $$ \begin{aligned} dp(i,j)=&\underbrace{dp(i-1,j)}_{\text{先白后黑,需保证$j>0$}}+\underbrace{dp(i-1,j)}_{\text{先黑后白,需保证 $n-j>0$}}+\\ &\underbrace{dp(i-1,j-1)}_{\text{先黑后黑,需保证 $j> 0$}}+\underbrace{dp(i-1,j+1)}_{\text{先白后白,需保证 $n-j>0$}} \end{aligned} $$ 但是发现会算重。 原因是有可能一些相同的操作(即相同的颜色序列)被开始白球数不同的多种情况计算了多次。 ![](https://cdn.luogu.com.cn/upload/image_hosting/uapjhjd0.png) 比如说像上图一样,用一条折线代表一种情况。尽管两种情况初始白球数量不一样,结束时白球数量不一样,但是拿出来的球颜色序列是一样的。 所以需要找到一个恰当的方法,去除这些重复的部分。 考虑钦定只计算某一种特定的方案。 发现选 $0$ 这个点是最合适的,就是说,我们只计算所有颜色序列的情况中,白球数曾经是 $0$ 的那一个。 放到图上来,就是只计算经过 $x$ 轴的那条折线。 显然这种折线只有一条,因为折线不能跨越 $x$ 轴,即白球数量在 dp 过程中不可能小于 $0$。 这种特殊的性质主要取决于我们不需要更新白球数量小于 $0$ 的部分,我们也不需要用白球数量小于 $0$ 的部分来更新其他状态。 然后考虑如何计算这种情况的方案数,最简单的就是直接设 $dp(i,j,1/0)$ 表示前 $i$ 次操作,当前箱子里还剩下 $j$ 个白球,白球数量是/否达到过 $0$ 的情况数。 当前还有好写一点的,也就是我要介绍的这种方法。 设按照我们一开始说的方法 dp,$n=x$ 的情况下的答案是 $solve(x)$,那么答案就是 $solve(n)-solve(n-1)$。 怎么理解呢? $solve(n)$ 表示的是当初始球数为 $0\sim n$ 时的答案。 $solve(n-1)$ 表示的是当初始球数为 $0\sim (n-1)$ 时的答案。但是在这里,我们把它理解为初始球数为 $1\sim n$,且每时每刻箱子里的白球都大于等于 $1$ 个时的答案。也就是说我固定住一个白球在箱子里不动。 那么相减是什么意思呢? ![](https://cdn.luogu.com.cn/upload/image_hosting/d4s6app3.png) 如图,$solve(n)-solve(n-1)$ 其实可以看做图中 $\color{green}{\text{绿线}}$ 与 $\color{purple}{\text{紫线}}$ 围住的情况数减去 $\color{blue}{\text{蓝线}}$ 与 $\color{purple}{\text{紫线}}$ 围住的情况数,那么最后在这些颜色序列相同的情况中,只剩下了唯一的 $\color{red}{\text{红色折线}}$ 所代表的方案。 所以这种方法也是可行的。 代码如下: ```cpp #include<bits/stdc++.h> #define N 3010 #define mod 1000000007 using namespace std; int n,m,dp[N][N]; int solve(int n) { memset(dp,0,sizeof(dp)); for(int i=0;i<=n;i++) dp[0][i]=1; for(int i=0;i<m;i++) { for(int j=0;j<=n;j++) { if(j) dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod;//先白后黑 if(n-j) dp[i+1][j]=(dp[i+1][j]+dp[i][j])%mod;//先黑后白 if(j) dp[i+1][j-1]=(dp[i+1][j-1]+dp[i][j])%mod;//白白 if(n-j) dp[i+1][j+1]=(dp[i+1][j+1]+dp[i][j])%mod;//黑黑 } } int ans=0; for(int i=0;i<=n;i++) ans=(ans+dp[m][i])%mod; return ans; } int main() { scanf("%d%d",&n,&m); printf("%d\n",(solve(n)-solve(n-1)+mod)%mod); return 0; } ```