题解:CF2122E Greedy Grid Counting

· · 题解

首先,我们假设所有 a_{i,j} 均为已知数。

首先我们对于一个网格(不考虑其子网格)观察何时贪心走法是优的。

首先行数为 1 的网格是一定对的。我们要考虑的是两行的网格。

如图示,红线为贪心走法,两个黑线是其他的走法。

为了保证红线的最优性,上面的绿框所框住的 a_{1,j} 之和应大于等于下面绿框所框住的 a_{2,j} 之和。同理,下面的蓝框所框住的 a_{2,j} 之和应大于等于上面蓝框所框住的 a_{1,j} 之和。

我们注意到框的位置规律:第一行的框总是比第二行的框要靠后一个单位距离,这启示我们考虑位移 1 的元素之差。

于是,设 v_{i}=a_{1,i+1}-a_{2,i}

现在我们希望知道拐点需要满足什么性质。

假设贪心走法的拐点为 k,则应该有:

\begin{cases} \sum_{j=i}^{k-1}v_j \geq 0 &&& (i<k)\\ \sum_{j=k}^{i-1}v_j \leq 0 &&& (i>k)\\ v_k \leq 0 \end{cases}

同时,我们可以发现当 v_k=0 时,若 v_{k+1} \leq 0,则:

\begin{cases} \sum_{j=i}^{k}v_j \geq 0 &&& (i<k+1)\\ \sum_{j=k+1}^{i-1}v_j \leq 0 &&& (i>k+1)\\ v_{k+1} \leq 0 \end{cases}

也即 k+1 必定优于 k

如果 v_{k+1}>0,显然也是没有必要在这向下拐的。我们可以一直向后考虑到下一个 v_{k'} \leq 0,显然也会满足上面的式子。

于是拐点其实就是第一个 v_k<0

然后我们可以判断网格是否可行。

由于拐点前都有 v_i \geq 0,所以 \sum_{j=i}^{k-1}v_j \geq 0 (i<k) 是一定成立的。

所以判定就考虑 \sum_{j=k}^{i-1}v_j \leq 0(i>k) 即可。

回到原来的问题(假设知道所有 a_{i,j} 且考虑所有子网格)。

上面的形式表明,第一个 v<0 位置相同的大小两个网格,大网格满足时小网格必定满足。

我们考虑先将序列中所有 v_{k_x}<0 提出。他们都要满足 \sum_{j=k_x}^{i-1}v_j \leq 0(i>k_x)。如果我们倒着依次考虑这些点,就可以发现条件可以归纳简化,表示为 \sum_{j=k_x}^{k_{x+1}-1}v_j \leq 0 \Rightarrow -v_{k_x} \geq \sum_{j=k_x+1}^{k_{x+1}-1}v_j

于是我们可以这样判定网格是否合法:从左往右扫,跳过开头的一段极长 v \geq 0,维护一个值 x,如果 v_i<0,则是一个新段的开始,直接令 x \mapsto -v_i,否则 x \mapsto x-v_i。序列合法当且仅当 x 无论何时均有 x \geq 0

考虑计数,扫过去的时候枚举 v_i 的可能值及方案数,用 DP 维护上述判定过程即可。时间复杂度 O(nk^2)

参考构思:

#include<bits/stdc++.h>
const int mod=998244353,inv2=499122177;
using namespace std;
int n,k,x[505],y[505];
int f[2][505],f0;
void tmain(){
    cin>>n>>k;
    for(int i=1;i<=n;i++)cin>>x[i];
    for(int i=1;i<=n;i++)cin>>y[i];
    f0=1;
    for(int i=0;i<=k;i++)f[1][i]=0;
    for(int i=1;i<n;i++){
        swap(f[1],f[0]);
        for(int j=0;j<=k;j++)f[1][j]=0;
        if(x[i+1]!=-1){
            if(y[i]!=-1){
                int a=x[i+1]-y[i];
                if(a>=0){for(int i=0;i+a<=k;i++)(f[1][i]+=f[0][i+a])%=mod;}
                else{
                    a=-a;
                    for(int i=0;i<=k;i++)(f[1][a]+=f[0][i])%=mod;
                    (f[1][a]+=f0)%=mod;
                    f0=0;
                }
            }
            else{
                for(int yi=1;yi<=x[i+1];yi++){
                    int a=x[i+1]-yi;
                    for(int i=0;i+a<=k;i++)(f[1][i]+=f[0][i+a])%=mod;
                }
                for(int yi=x[i+1]+1;yi<=k;yi++){
                    int a=yi-x[i+1];
                    for(int i=0;i<=k;i++)(f[1][a]+=f[0][i])%=mod;
                    (f[1][a]+=f0)%=mod;
                }
                f0=1ll*x[i+1]*f0%mod;
            }
        }
        else{
            if(y[i]!=-1){
                for(int xi=y[i];xi<=k;xi++){
                    int a=xi-y[i];
                    for(int i=0;i+a<=k;i++)(f[1][i]+=f[0][i+a])%=mod;
                }
                for(int xi=1;xi<y[i];xi++){
                    int a=y[i]-xi;
                    for(int i=0;i<=k;i++)(f[1][a]+=f[0][i])%=mod;
                    (f[1][a]+=f0)%=mod;
                }
                f0=1ll*(k-y[i]+1)*f0%mod;
            }
            else{
                for(int a=0;a<=k-1;a++){
                    int ct=k-a;
                    for(int i=0;i+a<=k;i++)(f[1][i]+=1ll*ct*f[0][i+a]%mod)%=mod;
                    if(a==0)continue;
                    for(int i=0;i<=k;i++)(f[1][a]+=1ll*ct*f[0][i]%mod)%=mod;
                    (f[1][a]+=1ll*f0*ct%mod)%=mod;
                }
                f0=1ll*(1+k)*k%mod*inv2%mod*f0%mod;
            }
        }
    }
    int ans=0;
    for(int i=0;i<=k;i++)(ans+=f[1][i])%=mod;
    (ans+=f0)%=mod;
    if(x[1]==-1)ans=1ll*ans*k%mod;
    if(y[n]==-1)ans=1ll*ans*k%mod;
    cout<<ans<<'\n';
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int T;cin>>T;while(T--)tmain();
    return 0;
}