AT_abc273_g [ABC273G] Row Column Sums 2
_•́へ•́╬_
2022-10-21 21:29:43
## 思路
数一下读入的里面,行上有 $cnt_1$ 个 $1$,行上有 $cnt_2$ 个 $2$,列上有 $cnt_3$ 个 $1$,列上有 $cnt_4$ 个 $2$ 的答案。
当且仅当
$$
cnt_1+2\times cnt_2\neq cnt_3+2\times cnt_4
$$
无解。
-----------------
设 $f(i,j,k,l)$ 为行上有 $i$ 个 $1$,行上有 $j$ 个 $2$,列上有 $k$ 个 $1$,列上有 $l$ 个 $2$ 的答案。
考虑转移。
- 当 $l>0$ 时:注意有 $i+2\times j=k+l\times 2$ 成立!
考虑分配一个 列上为 $2$ 的东西:
- 给行上的一个 $1$ 和一个 $2$:$f(i-1+1,j-1,k,l-1)\times i\times j$。(就是那个 $1$ 没了,$2$ 变成 $1$ 了)
- 给行上的两个 $1$:$f(i-2,j,k,l-1)\times C_i^2$。
- 给行上的两个 $2$:$f(i+2,j-2,k,l-1)\times C_j^2$。
- 给行上的一个 $2$:$f(i,j-1,k,l-1)\times j$。
当 $l=0$ 时:注意有 $i+2\times j=k$ 成立!(这点对于理解下面的式子很重要)
答案就是
$$
C_k^1\times C_{k-1}^1\times \cdots\times C_{2\times j+1}^1\times C_{2\times j}^2\times C_{2\times(j-1)}^2\times\cdots\times C_2^2
$$
$$
=\frac{k!}{2^j}
$$
---------------
注意到在转移的过程中 $k$ 始终不变,就是 $cnt_3$。
另外因为始终有那个等式成立(不成立就无解了),所以得到 $i,j,l$ 中的两个就能算出另一个,所以只需要吧其中的两个作为数组下标存着。
复杂度 $\mathcal{O}(n^2)$。
## code
```cpp
#include<stdio.h>
#include<string.h>
#define mod 998244353
inline char nc()
{
static char buf[99999],*l,*r;
return l==r&&(r=(l=buf)+fread(buf,1,99999,stdin),l==r)?EOF:*l++;
}
inline void read(int&x)
{
char c=nc();for(;c<'0'||'9'<c;c=nc());
for(x=0;'0'<=c&&c<='9';x=(x<<3)+(x<<1)+(c^48),c=nc());
}
int n,a,cnt1,cnt2,cnt3,cnt4,ans[5001][5001],fac[5001];
inline long long ksm(long long a,int b)
{
long long ans=1;
for(;b;b>>=1,a*=a,a%=mod)if(b&1)ans*=a,ans%=mod;
return ans;
}
inline long long dfs(const int&i,const int&j,const int&k,const int&l)
{
if(!l)return fac[k]*ksm(ksm(2,j),mod-2)%mod;
if(~ans[i][l])return ans[i][l];
ans[i][l]=0;
if(i&&j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*i%mod*j)%mod;
if(i>1)ans[i][l]=(ans[i][l]+dfs(i-2,j,k,l-1)*(i*(i-1ll)>>1))%mod;
if(j>1)ans[i][l]=(ans[i][l]+dfs(i+2,j-2,k,l-1)*(j*(j-1ll)>>1))%mod;
if(j)ans[i][l]=(ans[i][l]+dfs(i,j-1,k,l-1)*j)%mod;
return ans[i][l];
}
main()
{
fac[0]=1;for(int i=1;i<5001;fac[i]=(long long)(fac[i-1])*i%mod,++i);
memset(ans,-1,sizeof(ans));read(n);
for(int i=n;i--;read(a),a==1&&++cnt1,a==2&&++cnt2);
for(int i=n;i--;read(a),a==1&&++cnt3,a==2&&++cnt4);
if(cnt1+cnt2+cnt2^cnt3+cnt4+cnt4){putchar('0');return 0;}
printf("%lld",dfs(cnt1,cnt2,cnt3,cnt4));
}
```