题解 P5857 【「SWTR-03」Matrix】

· · 题解

Preface

用 cmd 的引言:

OI 中的四类组合问题 : 判定问题,构造问题,极化问题,计数问题。 NOIP 命题风向 : 弱化知识点,强化思维。

这题被 cmd 拿来作 NOIP 模拟赛 T2 了…… orz SWTR。

确实是一道计数好题,考察基本功和思维。

Solution

部分分参考出题人的题解。

将一次操作分解成一行一列来思考。不难发现一行被操作奇数次时才能改变状态。于是若有 ij 列都被操作了奇数次,那么这次操作产生的答案就是:

\binom{n}{i}\binom{m}{j}

注意满足恰好总操作数等于 k 的条件为:0\le i,j\le k,i\equiv k\pmod2,j\equiv k\pmod2

那么枚举 i,j ,不难得到总答案:

\sum_{i=0}^{\min\{n,k\}}\sum^{\min\{m,k\}}_{j=0}\binom{n}{i}\binom{m}{j}[i\equiv k\pmod 2][j\equiv k\pmod 2]

发现两个 \sum 可以乘开,运用乘法分配律就有:

(\sum_{i=0}^{\min\{n,k\}}\binom{n}{i}[i\equiv k\pmod 2])\times(\sum^{\min\{m,k\}}_{j=0}\binom{m}{j}[j\equiv k\pmod 2])

分别计算乘号两边,时间复杂度 O(n)

然而不同的行列状态可能映射到同一个矩阵上,要减去这些重复的方案。

举个例子:

如果当前行/列状态为 1,那么是被操作了奇数次,0 就是操作了偶数次。

下图 i=2,j=2

行/列 1 1
1 0 0
1 0 0

下图 i=0,j=0

行/列 0 0
0 0 0
0 0 0

两种合法的状态不同,但对应了同一个矩阵。再举例子发现,在 (i,j) 合法的情况下,若 (n-i),(m-j)\le k,(n-i)\equiv k\pmod 2,(m-j)\equiv k\pmod 2,则 (n-i,m-j) 也是一个合法的且重复的方案。

统计的时候计算出这些重复的,用总的减去重复的即可。

Notice

Code

本代码需要C++11。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cctype>
#include <cstring>
using namespace std;

template<typename tp>
void read(tp& a)
{
    register tp num=0;register int f=1;register char ch=getchar();
    while(!isdigit(ch) && ch!='-') ch=getchar();
    if(ch=='-') f=-1,ch=getchar();
    do
        num=num*10+int(ch-'0'),ch=getchar();
    while(isdigit(ch));
    a=num*f;
}
template<typename tp,typename...Args>
void read(tp& a,Args&...args){  read(a);read(args...);  }

typedef const int cint;
cint MOD=998244353;
cint MAXN=2e5+5;
int n,m,k;
int fact[MAXN],inv[MAXN];

int qpow(int base,int k)
{
    register int res=1;
    while(k)
    {
        if(k&1) res=1ll*res*base%MOD;
        base=1ll*base*base%MOD;
        k>>=1;
    }
    return res;
}

void calc(cint MAX)
{
    fact[0]=inv[0]=1;
    for(int i=1;i<=MAX;i++) fact[i]=1ll*fact[i-1]*i%MOD;
    for(int i=1;i<=MAX;i++) inv[i]=qpow(fact[i],MOD-2);
}

int C(cint n,cint m)
{
    return 1ll*fact[n]*inv[m]%MOD*inv[n-m]%MOD;
}

void solve(void)
{
    auto get=[&](cint l,cint r,cint N)
    {
        int res=0;
        for(int i=l;i<=r;i+=2)  res=(1ll*res+C(N,i))%MOD;
        return res;
    };
    int ans1=0,ans2=0,ans3=0,ans4=0;
    ans1=get(min(k,n)&1,min(k,n),n);
    ans2=get(min(k,m)&1,min(k,m),m);
    if(n%2==0 && m%2==0)
    {
        ans3=get(max(n-k,0),min(k,n),n);
        ans4=get(max(m-k,0),min(k,m),m);
    }
    printf("%lld\n",(1ll*ans1*ans2%MOD-1ll*ans3*ans4%MOD*inv[2]%MOD+MOD)%MOD);
}

int main()
{
    calc(2e5+2);
    int T;read(T);
    while(T--)
    {
        read(n,m,k);
        solve();
    }
    return 0;
}