CF1608F MEX counting

· · 题解

一道不错的 dp 练手题(迫真

写之前最好理清楚 dp 的整体思路,否则可能写着写着就晕了。

题意

求有多少个长度为 n 的序列 a 满足对于任意的 1 \le i \le n 都有:

Solution

下文中用 K 表示题目描述中给出的 k

首先第二个条件显然就是让每一个前缀 MEX 都在一个值域区间内,我们设第 i 个区间为 [l_i,r_i]

然后推一推就能设计出状态 dp_{i,j,k} 表示考虑到 i ,目前的 MEXk,之前的位置中有恰好有 j不同的且大于当前 MEX 的值。

满足条件只要时刻保证 l_i \le K \le r_i 即可。

状态这么设计的原因是,我们并不关系那些大于 MEX 的位置的具体的值,因为我们可以等到他对 MEX 会产生影响的时候再去确定它。

j 要描述不同的值是因为同一种值只会对 MEX 造成 1 的影响。

然后我们考虑下面两种情况的转移:

最后贡献答案的时候和第二种情况一样,乘个排列数就行了。

这个 dp 状态 \text{O}(n^3),转移 \text{O}(n)

所以我们有了一个 \text{O}(n^4) 的做法。

优化1

非常显然,dp 状态中 k 这一维有用的状态数其实只有 \text{O}(K) 个。

所以状态复杂度变成 \text{O}(n^2K),转移复杂度变成 \text{O}(K)

复杂度降为 \text{O}(n^2K^2)

优化2

显然复杂度瓶颈在于最后一种转移。

观察就能发现,这个转移所乘的阶乘其实就是转移前后的两个状态的 j 下标,可以直接前缀和。

但是直接搞前缀和如果要保证复杂度正确,边界会变得很复杂。

为了方便,可以用 dp_{i,j+k,k} 来代替 dp_{i,j,k}

所以转移就变成了:

j \times dp_{i,j,k} \rightarrow dp_{i+1,j,k} dp_{i,j,k} \rightarrow dp_{i+1,j+1,k} \dfrac{(j-k)!}{(j+1-t)!} \times dp_{i,j,k} \rightarrow dp_{i+1,j+1,t}

直接平行前缀和就行了。

所以转移复杂度变成了 \text{O}(1),最终时间复杂度为 \text{O}(n^2K)

空间上滚动数组一下就好了。

Code

const int N=2005,K=105,mod=998244353;
int n,k,l[N],r[N],dp[2][N][N],sm[2][N][N],jc[N],inv[N];

inline int q_pow(int a,int b)
{
    int c=1;
    while(b)
    {
        if(b&1) c=a*c%mod;
        a=a*a%mod;b>>=1;
    }
    return c;
}

inline void init(int x)
{
    jc[0]=1;
    for(int i=1;i<=x;++i)
        jc[i]=jc[i-1]*i%mod;
    inv[x]=q_pow(jc[x],mod-2);
    for(int i=x-1;i>=0;--i)
        inv[i]=inv[i+1]*(i+1)%mod;
}

signed main()
{
    init(2000);
    n=read();k=read();
    for(int i=1;i<=n;++i)
    {
        int x=read();
        l[i]=max(0,x-k);
        r[i]=min(i,x+k);
    }
    dp[0][0][0]=1;
    sm[0][0][0]=1;
    for(int I=1,i=1;I<=n;++I,i^=1)
    {
        for(int j=0;j<=I;++j)
            for(int k=l[I];k<=r[I]&&k<=j;++k)
            {
                dp[i][j][k]=(dp[i][j][k]+dp[i^1][j][k]*j%mod)%mod;
                if(j) dp[i][j][k]=(dp[i][j][k]+dp[i^1][j-1][k])%mod;
                if(j&&k) dp[i][j][k]=(dp[i][j][k]+sm[i^1][j-1][min(k-1,r[I-1])]*inv[j-k]%mod)%mod;
                sm[i][j][k]=dp[i][j][k]*jc[j-k]%mod;
                if(k) sm[i][j][k]=(sm[i][j][k]+sm[i][j][k-1])%mod;
            }
        for(int j=0;j<=I-1;++j)
            for(int k=l[I-1];k<=r[I-1]&&k<=j;++k)
                dp[i^1][j][k]=sm[i^1][j][k]=0;
    }
    int ans=0;
    for(int i=0;i<=n;++i)
        for(int j=l[n];j<=r[n]&&j<=i;++j)
            ans=(ans+dp[n&1][i][j]*jc[n-j]%mod*inv[n-i]%mod)%mod;
    write(ans);
    return 0;
}