题解:AT_abc273_g [ABC273G] Row Column Sums 2

· · 题解

很好的 dp 题。

因为 0 \le R_i,C_i \le 2,所以矩阵中的所有元素也一定是在 [0,2] 的范围内。

考虑一个三维的 dp,设 f_{i,j,k} 表示前 i 行有 j 列和为 1,有 k 列和为 2 的方案数。

第一维应该可以滚动数组掉,但是时间复杂度就是 O(n^3) 的,显然不可接受。

发现 k 其实是可以通过 j 推导出来的,因此其实可以实现一个一维的 dp。

f_{i,j} 表示前 i 行有 j 个“列的和为 1”的要求待满足的方案数,可以滚动数组滚掉一维,然后就可以直接状态转移了,这个状态转移方程是好推导的。时间复杂度 O(n^2),空间复杂度 O(n)

但为什么我的代码 WA 了两个点啊(恼。

注意最后如果 sum=j+k\neq 0 的话要输出 0!!!

写得很丑的代码:

#include<bits/stdc++.h>
using namespace std;

#define ll long long 
const int N=7005;
const ll mod=998244353;

ll qpow(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1) res=res*a%mod;
        a=a*a%mod,b>>=1;
    }
    return res;
}
ll inv(ll x){
    return qpow(x,mod-2);
}

int n,r[N],c[N];
ll f[N],g[N],iv2=inv(2);
ll C2(ll x){
    return x*(x-1)%mod*iv2%mod;
}
void solve(){
    cin>>n;
    int cnt1=0,cnt2=0;
    for(int i=1;i<=n;i++) cin>>r[i];
    for(int i=1;i<=n;i++) cin>>c[i],cnt1+=(c[i]==1),cnt2+=(c[i]==2);

    int sum=cnt1+cnt2*2;

    f[cnt1]=1;
    for(int i=1;i<=n;i++){
        memset(g,0,sizeof(g));
        for(int j=0;j<=n;j++){
            if(f[j]==0||sum-j<0||(sum-j)%2) continue;
            int k=(sum-j)/2;
            if(r[i]==0) g[j]=(g[j]+f[j])%mod;
            else if(r[i]==1){
                if(j>=1) g[j-1]=(g[j-1]+f[j]*j%mod)%mod;
                if(k>=1) g[j+1]=(g[j+1]+f[j]*k%mod)%mod;
            }
            else{
                if(k>=1) g[j]=(g[j]+f[j]*k%mod)%mod;
                if(k>=2) g[j+2]=(g[j+2]+f[j]*C2(k)%mod)%mod;
                if(j>=2) g[j-2]=(g[j-2]+f[j]*C2(j)%mod)%mod;
                if(j>=1&&k>=1) g[j]=(g[j]+f[j]*j%mod*k%mod)%mod;
            }
        }
        memcpy(f,g,sizeof(g));
        sum-=r[i];
    }
    if(sum) return cout<<0<<"\n",void();
    return cout<<f[0]<<"\n",void();
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    int T=1;
    while(T--){
        solve();
    }

    return 0;
}