AT_abc273_g [ABC273G] Row Column Sums 2

_•́へ•́╬_

2022-10-21 21:29:43

Solution

## 思路 数一下读入的里面,行上有 $cnt_1$ 个 $1$,行上有 $cnt_2$ 个 $2$,列上有 $cnt_3$ 个 $1$,列上有 $cnt_4$ 个 $2$ 的答案。 当且仅当 $$ cnt_1+2\times cnt_2\neq cnt_3+2\times cnt_4 $$ 无解。 ----------------- 设 $f(i,j,k,l)$ 为行上有 $i$ 个 $1$,行上有 $j$ 个 $2$,列上有 $k$ 个 $1$,列上有 $l$ 个 $2$ 的答案。 考虑转移。 - 当 $l>0$ 时:注意有 $i+2\times j=k+l\times 2$ 成立! 考虑分配一个 列上为 $2$ 的东西: - 给行上的一个 $1$ 和一个 $2$:$f(i-1+1,j-1,k,l-1)\times i\times j$。(就是那个 $1$ 没了,$2$ 变成 $1$ 了) - 给行上的两个 $1$:$f(i-2,j,k,l-1)\times C_i^2$。 - 给行上的两个 $2$:$f(i+2,j-2,k,l-1)\times C_j^2$。 - 给行上的一个 $2$:$f(i,j-1,k,l-1)\times j$。 当 $l=0$ 时:注意有 $i+2\times j=k$ 成立!(这点对于理解下面的式子很重要) 答案就是 $$ C_k^1\times C_{k-1}^1\times \cdots\times C_{2\times j+1}^1\times C_{2\times j}^2\times C_{2\times(j-1)}^2\times\cdots\times C_2^2 $$ $$ =\frac{k!}{2^j} $$ --------------- 注意到在转移的过程中 $k$ 始终不变,就是 $cnt_3$。 另外因为始终有那个等式成立(不成立就无解了),所以得到 $i,j,l$ 中的两个就能算出另一个,所以只需要吧其中的两个作为数组下标存着。 复杂度 $\mathcal{O}(n^2)$。 ## code ```cpp #include<stdio.h> #include<string.h> #define mod 998244353 inline char nc() { static char buf[99999],*l,*r; return l==r&&(r=(l=buf)+fread(buf,1,99999,stdin),l==r)?EOF:*l++; } inline void read(int&x) { char c=nc();for(;c<'0'||'9'<c;c=nc()); for(x=0;'0'<=c&&c<='9';x=(x<<3)+(x<<1)+(c^48),c=nc()); } int n,a,cnt1,cnt2,cnt3,cnt4,ans[5001][5001],fac[5001]; inline long long ksm(long long a,int b) { long long ans=1; for(;b;b>>=1,a*=a,a%=mod)if(b&1)ans*=a,ans%=mod; return ans; } inline long long dfs(const int&i,const int&j,const int&k,const int&l) { if(!l)return fac[k]*ksm(ksm(2,j),mod-2)%mod; if(~ans[i][l])return ans[i][l]; ans[i][l]=0; if(i&&j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*i%mod*j)%mod; if(i>1)ans[i][l]=(ans[i][l]+dfs(i-2,j,k,l-1)*(i*(i-1ll)>>1))%mod; if(j>1)ans[i][l]=(ans[i][l]+dfs(i+2,j-2,k,l-1)*(j*(j-1ll)>>1))%mod; if(j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*j)%mod; return ans[i][l]; } main() { fac[0]=1;for(int i=1;i<5001;fac[i]=(long long)(fac[i-1])*i%mod,++i); memset(ans,-1,sizeof(ans));read(n); for(int i=n;i--;read(a),a==1&&++cnt1,a==2&&++cnt2); for(int i=n;i--;read(a),a==1&&++cnt3,a==2&&++cnt4); if(cnt1+cnt2+cnt2^cnt3+cnt4+cnt4){putchar('0');return 0;} printf("%lld",dfs(cnt1,cnt2,cnt3,cnt4)); } ```