[THUPC2021] 游戏

· · 题解

Nim游戏先手必胜当且仅当所有石子的 Xor 非 0。

也就是说我们要求的是:把n个石子分成m堆,每堆不超过上限且 Xor 非 0 的方案数

而这个东西不太好求,可以转换为求:总方案数-每堆不超过上限且 Xor 为 0 的方案数

总方案数:

\sum\limits_{S}(-1)^{|S|}\times\binom{n-\sum_{i\in S}(a_i+1)+m-1}{m-1}

直接 2^m 枚举子集即可。

考虑怎么求堆不超过上限且 Xor 为 0 的方案数,容易想到数位 Dp,只不过我们现在需要把 m 堆是否顶着上限都要压入状态中。

所以设:f_{i,s,j} 表示从高到低 Dp 到第 i 位,目前被卡上界的集合是 s,低位向这一位进位 j 次,注意 j 的范围只能是 0\to m-1

考虑直接枚举每一个数在这一位填 0/1,复杂度约为 O(4^mm\log n),应该无法通过。

注意到我们只需要枚举目前还顶着上界的,没顶着上界的只需要枚举有几个 1 再乘上组合数即可,这样复杂度就降为了 O(3^mm^2\log n),代码颇多细节。

#include<bits/stdc++.h>
#define N 11
#define ll long long
using namespace std;
const int mod=998244353;
ll n,a[N];
int m,S,c[N][N],f[2][1<<N][N],num[1<<N],cnt[1<<N],t[1<<N][2],tot;
inline int Mod(int x) {return x<mod?x:x-mod;}
inline void Add(int&x,int y) {x=Mod(x+y);}
inline int Ksm(int a,int n,int ans=1) {for(; n; n>>=1,a=1ll*a*a%mod)if(n&1)ans=1ll*ans*a%mod;return ans;}
int C(ll a,int b,int ans=1) {for(int i=1; i<=b; ++i)ans=1ll*ans*((a-i+1)%mod)%mod*Ksm(i,mod-2)%mod;return ans;}
inline int Solve1(int ans=0) {
    for(int i=0; i<(1<<m); ++i) {
        int opt=1;ll S=0;
        for(int j=0; j<m; ++j)if(i>>j&1)opt=-opt,S+=a[j]+1;
        if(S<=n)ans=Mod(Mod(ans+opt*C(n-S+m-1,m-1))+mod);
    } return ans;
}
inline int Solve2(int ans=0) {
    if(n&1)return 0;
    int nw=0;f[nw][S=(1<<m)-1][0]=1;
    for(int i=59; ~i; --i) {
        nw^=1;for(int j=0; j<=S; ++j)for(int k=0; k<m; ++k)f[nw][j][k]=0;
        for(int j=0,tmp; j<=S; ++j)
            for(int k=0; k<m; ++k)if(tmp=f[nw^1][j][k]) {
                int ct=0;for(int l=0; l<m; ++l)ct+=!(j>>l&1);
                for(int l=j,bj,num; 1; l=(l-1)&j) {
                    bj=1;for(int r=0; r<m; ++r)if((j>>r&1)&&!(l>>r&1)&&!(a[r]>>i&1))bj=false;
                    if(bj) {
                        for(int r=num=0; r<m; ++r)if((j>>r&1)&&(l>>r&1)&&(a[r]>>i&1))++num;
                        for(int r=num&1,A; r<=ct; r+=2)
                            if((r+num>>1)<=k&&(A=(k-((r+num)>>1)<<1)+(n>>i&1))<m)Add(f[nw][l][A],1ll*tmp*c[ct][r]%mod);
                    } if(!l)break;
                }
            }
    }
    for(int j=0; j<=S; ++j)ans=Mod(ans+f[nw][j][0]);return ans;
}
int main() {
    cin>>n>>m;
    for(int i=0; i<m; ++i)cin>>a[i];
    for(int i=0; i<=m; ++i)c[i][0]=1;
    for(int i=1; i<=m; ++i)for(int j=1; j<=i; ++j)c[i][j]=Mod(c[i-1][j]+c[i-1][j-1]);
    cout<<Mod(Solve1()-Solve2()+mod);
    return 0;
}