题解 AT2294【Eternal Average】

· · 题解

题解

首先很容易看出来,答案一定是给每个 1 和 0 配上一个 k^{-x} 的系数再加起来,满足 \sum k^{-x_i}=1,求 Z=\sum_{a_i=1}k^{-x_i} 有多少种取值。

显然需要依次考虑 k 进制下 Z 的每个小数位的取值。设 Z=0.z_1z_2...z_p,z_p>0,假设 \sum_{a_i=1}k^{-x_i} 从低位到高位求和的过程中没有进位,那么肯定满足 \sum z_i=n,若考虑进位,那么满足 \sum z_i\le n。但是 \sum z_i 或者说每一个 z_i 具体的限制究竟是什么?如果你像傻*我一样到了这里就不继续推导而想着优化暴力DP的话,那么你最终会在发现自己的四次方五次方DP怎么都会把方案算重的时候崩溃掉。

其实我们可以很轻易地发现,当产生进位的时候,低位处减少 x*k ,高位处会增加 x,此时 \sum z_i 刚好减少 k-1 的整数倍,也就是说 \sum z_i 仍满足 \sum z_i=n(\bmod k-1)。归纳验证发现,\sum z_i\le n\sum z_i=n(\bmod k-1) 这两条限制已经足够。

还剩下 \sum k^{-x_i}=1 这个条件,我们可以发现它其实等价于 \sum_{a_i=0}k^{-x_i}=1-Z,我们把 1-Z=0.(k-1-z_1)(k-1-z_2)...(k-1-z_{p-1})(k-z_p) 套用上面的推导即可得出 1+(k-1)p-\sum z_i\le m1+(k-1)p-\sum z_i=m(\bmod k-1) 的限制。

然后DP的转移就很明显了,只需要设 dp[i][j][0/1] 表示考虑到小数点后第 i 位,\sum_{x=1}^i z_x=jz_i 等于/大于 0 时的取值个数,转移方程为

dp[i][j][0]=dp[i-1][j][0]+dp[i-1][j][1] dp[i][j][1]=\sum_{x=1}^{k-1}(\,dp[i-1][j-x][0]+dp[i-1][j-x][1]\,)

算答案只需要统计所有 j 满足条件的 dp[i][j][1] 的和即可。

然后就做完了。不需要前缀和优化,因为第一维不超过总操作数,故直接转移复杂度是 O(\frac{n+m-1}{k-1}kn)=O(n^2)

代码

#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define uns unsigned
#define IF (it->first)
#define IS (it->second)
#define END putchar('\n')
using namespace std;
const int MAXN=4005;
const ll INF=1e18;
inline ll read(){
    ll x=0;bool f=1;char s=getchar();
    while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
    while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
    return f?x:-x;
}
int ptf[50],lpt;
inline void print(ll x,char c='\n'){
    if(x<0)putchar('-'),x=-x;
    ptf[lpt=1]=x%10;
    while(x>9)x/=10,ptf[++lpt]=x%10;
    while(lpt)putchar(ptf[lpt--]^48);
    if(c>0)putchar(c);
}
inline ll lowbit(ll x){return x&-x;}

const ll MOD=1e9+7;
int n,m,k,p;
ll dp[MAXN][MAXN][2],ans;
inline void ad(ll&a,ll b){a+=b;if(a>=MOD)a-=MOD;}
signed main()
{
    n=read(),m=read(),k=read(),p=(n+m-1)/(k-1);
    dp[0][0][0]=1;
    for(int i=1;i<=p;i++){
        for(int j=0;j<=n;j++){
            dp[i][j][0]=(dp[i-1][j][0]+dp[i-1][j][1])%MOD;
            for(int x=1;x<k&&x<=j;x++)
                ad(dp[i][j][1],dp[i-1][j-x][1]),
                ad(dp[i][j][1],dp[i-1][j-x][0]);
        }
        for(int j=0;j<=n;j++)
            if(j%(k-1)==n%(k-1)&&((k-1)*i+1-j)<=m&&((k-1)*i+1-j)%(k-1)==m%(k-1))
                ad(ans,dp[i][j][1]);
    }
    print(ans);
    return 0;
}