题解:CF2034F2 Khayyam's Royal Decree (Hard Version)

· · 题解

思路

首先有很简单的 \Theta(k^3) 容斥做法,但是这个做法非常不优美,没有利用好每次将权值 \times 2 的性质,因此考虑利用一下这个性质。

我先把 (r_i,b_i) 变成了 (n-r_i,m-b_i),这样变换后的 (r_i,b_i) 的意义就变成了当前选了 r_i 个红宝石,b_i 个蓝宝石。我们称现在的 (r_i,b_i) 为关键点。

考虑组合意义,我们假设一个方案经过了 k 个关键点 p_1,p_2,\dots,p_kp_0 是起点,p_{k+1} 是终点,则该方案的权值是 \sum (2(r_{p_{i}}-r_{p_{i-1}})+(b_{p_{i}}-b_{p_{i-1}})) \times 2^{k+1-i}

然后化简一下,变成:\sum_{i=1}^k (2 \times r_{p_i}+b_{p_i}) \times (2^{k+1-i}-2^{k-i})+(2 \times r_{p_{k+1}}+b_{p_k+1})=\sum_{i=1}^k (2 \times r_{p_i}+b_{p_i}) \times 2^{k-i}+(2 \times r_{p_{k+1}}+b_{p_k+1}),我们发现 2^{k-i} 的组合意义就是在 [i+1,k] 中任选一个标记点集合的方案数。

然后把贡献拆开来,对于 (2 \times r_{p_{k+1}}+b_{p_k+1}) 我们直接计算一下从 (0,0) 走到 (n,m) 的方案数即可,对于剩下的,相当于选一个 i,然后选择一个标记点集合 S,然后计算一下从 (0,0) 开始经过 i 和所有标记点最后到达 (n,m) 的方案数乘上 (2 \times r_i+b_i)

然后就能直接 DP 了,时间复杂度 \Theta(n+m+k^2)

//A tree without skin will surely die.
//A man without face will be alive.
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i,j,k) for(int i=j;i<=k;++i)
#define per(i,j,k) for(int i=j;i>=k;--i)
#define add(x,y) (x=((x+y>=mod)?(x+y-mod):(x+y)))
int const N=1e6+10,M=5e3+10,mod=998244353;
int n,m,k,fac[N],inv[N],r[M],b[M],dp[M],id[M];
inline int qpow(int a,int b){int res=1;while (b){if (b&1) res*=a,res%=mod;a*=a,a%=mod,b>>=1;}return res;}
inline int C(int n,int m){if (n<m || m<0) return 0;return fac[n]*inv[m]%mod*inv[n-m]%mod;}
inline void solve(){
    cin>>n>>m>>k;
    rep(i,1,k) cin>>r[i]>>b[i],r[i]=n-r[i],b[i]=m-b[i],id[i]=i;
    sort(id+1,id+k+1,[](int x,int y){return (r[x]==r[y])?(b[x]<b[y]):(r[x]<r[y]);});
    int div=qpow(C(n+m,m),mod-2),ans=C(n+m,m)*(n*2+m)%mod;
    rep(g,1,k){
        int i=id[g];
        dp[i]=C(r[i]+b[i],r[i])*(r[i]*2+b[i])%mod;
        rep(j,1,k) if (j!=i && r[j]<=r[i] && b[j]<=b[i])
            add(dp[i],C(r[i]-r[j]+b[i]-b[j],r[i]-r[j])*dp[j]%mod)%mod;
        add(ans,C(n-r[i]+m-b[i],n-r[i])*dp[i]%mod);
    }
    cout<<ans*div%mod<<'\n';
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    fac[0]=1;
    rep(i,1,N-1) fac[i]=fac[i-1]*i%mod;
    inv[N-1]=qpow(fac[N-1],mod-2);
    per(i,N-2,0) inv[i]=inv[i+1]*(i+1)%mod;
    int t=1;
    cin>>t;
    while (t--) solve();
    return 0;
}