题解:P10328 [UESTCPC 2024] 卡牌游戏

· · 题解

P10328 [UESTCPC 2024] 卡牌游戏 题解

赛场上想出来了但是不敢写,怕调不出来被队友打。

解决方案

首先容易想到 k 对取牌的过程并没有影响,所以考虑矩阵快速幂,这样就解释了为什么 k\le 10^9

然后考虑矩阵的内容。n\le 4 有很大启发,每种牌都可能有或者没有,使用状压,虽然有 2^{16} 种方案,但容易想到实际合法的方案应该不多。打表发现只有 209 种。可以支持 O(n^3) 的矩阵乘法。

每个状态的得分是可以计算的。其次,要计算任意两个状态之间转移的概率。分几种情况:拿起一张没有的、拿起一张导致放回一张、拿起一张导致放回两张。注意第三种容易被忽视。

这些是算好了以后,就可以上矩阵了,如图:

其中 f_i 表示状态编号为 i 的状态发生的概率。注意因为 ans 只能计算上一次的贡献,因此 k 要变成 k+1

代码

注意细节。

#include<bits/stdc++.h>
using namespace std;typedef long long ll;const int N=220,MOD=998244353;
struct mat{
    ll n,m,a[N][N];
    mat(){n=m=0,memset(a,0,sizeof(a));}
    mat operator *(const mat b){
        mat res;res.n=n,res.m=b.m;
        for(int i=1;i<=res.n;i++)for(int j=1;j<=res.m;j++)
            for(int k=1;k<=m;k++)(res.a[i][j]+=a[i][k]*b.a[k][j]%MOD)%=MOD;
        return res;
    }
}x,res;ll n,sn,k,a[10][10],b[10][10],cnt,num[N],mv[N][N],point[N],f1,f2,tot,sum;
inline ll getid(ll x,ll y){return (x-1)*n+y-1;}
inline ll qpow(ll x,ll y){
    ll res=1;
    while(y){
        if(y&1)(res*=x)%=MOD;
        (x*=x)%=MOD,y>>=1;
    }
    return res;
}
inline bool check(int i,int x,int y){
    for(int j=1;j<=n;j++){
        if(x==-1&&((i>>getid(j,y))&1))return 0;
        else if(y==-1&&((i>>getid(x,j))&1))return 0;
    }
    return 1;
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0),cin>>n>>k,k++,sn=n*n;
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)cin>>a[i][j],(tot+=a[i][j])%=MOD;
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)cin>>b[i][j];
    for(int i=0;i<(1<<sn);i++){
        f1=1;
        for(int j=1;j<=n;j++)for(int k=1;k<=n;k++)
            if(!a[j][k]&&((i>>getid(j,k))&1)){f1=0;break;}
        if(!f1)continue;
        f1=1;
        for(int j=1;j<=n;j++)for(int k=1;k<=n;k++)for(int l=k+1;l<=n;l++)
            if((((i>>getid(j,k))&1)&&((i>>getid(j,l))&1))||
            (((i>>getid(k,j))&1)&&((i>>getid(l,j))&1))){f1=0;break;}
        if(f1){
            num[++cnt]=i;
            for(int j=1;j<=n;j++)for(int k=1;k<=n;k++)
                point[cnt]+=(((i>>getid(j,k))&1)*b[j][k]);
        }
    }
    for(int i=1;i<=cnt;i++)for(int j=1;j<=cnt;j++){
        f1=f2=-1;
        for(int k=1;k<=n;k++){
            for(int l=1;l<=n;l++)
                if(((num[i]>>getid(k,l))&1)+((num[j]>>getid(k,l))&1)==1){
                    if(f2>=0){f1=-2;break;}
                    else if(f1>=0)f2=getid(k,l);
                    else f1=getid(k,l);
                }
            if(f1==-2)break;
        }
        if(f1<0)continue;
        for(int k=1;k<=n;k++)for(int l=1;l<=n;l++)tot-=((num[i]>>getid(k,l))&1);
        if(f2>=0){
            int x=f1/n+1,y=f1%n+1,k=f2/n+1,l=f2%n+1;
            if(((num[i]>>f1)&1)&&((num[i]>>f2)&1))
                mv[i][j]=(a[x][l]+a[k][y])%MOD*qpow(tot,MOD-2)%MOD;
        }
        else{
            if((num[i]>>f1)&1){
                sum=0;
                for(int k=1;k<=n;k++){
                    if(check(num[i],k,-1)||k==f1/n+1)(sum+=a[k][f1%n+1])%=MOD;
                    if(k!=f1%n+1&&check(num[i],-1,k))(sum+=a[f1/n+1][k])%=MOD;
                }
                mv[i][j]=(sum-1+MOD)%MOD*qpow(tot,MOD-2)%MOD;
            }
            else mv[i][j]=a[f1/n+1][f1%n+1]*qpow(tot,MOD-2)%MOD;
        }
        for(int k=1;k<=n;k++)for(int l=1;l<=n;l++)tot+=((num[i]>>getid(k,l))&1);
    }
    x.n=x.m=res.m=cnt+1,res.n=1,res.a[1][1]=1,x.a[cnt+1][cnt+1]=1;
    for(int i=1;i<=cnt;i++)for(int j=1;j<=cnt;j++)x.a[i][j]=mv[i][j];
    for(int i=1;i<=cnt;i++)x.a[i][cnt+1]=point[i];
    while(k){
        if(k&1)res=res*x;
        x=x*x,k>>=1;
    }
    return cout<<res.a[1][cnt+1]<<"\n",0;
}