$~~~~\!\!~~~~$总数直接容斥可解。对于异或和为零,考虑数位 dp:从高到低位 dp,记录每个变量在高位的取值是否顶满了。
$~~~~\!\!~~~~$对于和的限制,类似背包,记录一维表示还差多少。显然对于一位 $i$ 我们只需要记录比这位更高的位置的信息,也就是记录一维 $x$ 表示还需要 $2^ix$ 的和,在进入下一位的时候更新 $x$。
$~~~~\!\!~~~~$考虑转移,直接的想法是枚举下一位 $m$ 个数的状态,但是这需要 $O(2^m)$ 复杂度转移显然不对:注意到我们在不同数之间并没有特殊限制,而只在意某一位上为 $1$ 的数的个数是否为偶数。因此我们可以在某一位上逐个数做转移(枚举该数当前位是否填 $1$),复杂度就到 $O(m)$ 了。
$~~~~\!\!~~~~$总的来说,就是 $f_{i,j,S,x,0/1}\to f_{i,j+1,S',x',0/1}$,还有 $f_{i,j,S,x,0/1}\to f_{i-1,0,S,x',0}$。具体的 $S',x'$ 就看是否填 $1$ 和 $n$ 下一位是否为 $1$ 即可。
``` cpp
#include<bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;
inline int pls(const int &x,const int &y){return x+y>=mod?x+y-mod:x+y;}
inline int add(int &x,const int &y){return x=pls(x,y);}
inline int sub(int &x,const int &y){return x=pls(x,mod-y);}
ll n; int m;
ll a[15];
int ansS;
int fpow(int a,int b){
if (!b) return 1;
int t=fpow(a,b>>1);
if (b&1) return 1ll*t*t%mod*a%mod;
return 1ll*t*t%mod;
}
int fac[15], inv[15];
void Init(){
fac[0]=inv[0]=1;
for (int i=1;i<=m;i++) fac[i]=1ll*fac[i-1]*i%mod;
inv[m]=fpow(fac[m],mod-2);
for (int i=m-1;i>=1;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
return ;
}
inline int binom(const int &n,const int &m){
int res=1; for (int i=0;i<m;i++) res=1ll*res*(n-i)%mod;
return 1ll*res*inv[m]%mod;
}
int f[70][15][1<<10][40][2];
int main(){
scanf("%lld %d",&n,&m); Init();
for (int i=1;i<=m;i++) scanf("%lld",a+i);
for (int S=0;S<(1<<m);S++){
ll x=n,f=1; for (int i=0;i<m;i++) if (S&(1<<i)) x-=a[i+1]+1, f=pls(mod,-f);
if (x<0) continue;
add(ansS,1ll*f*binom(pls(x%mod,m-1),m-1)%mod);
}
f[60][0][0][0][0]=1;
for (int i=60;i>=0;i--){
for (int j=1;j<=m;j++){
for (int S=0;S<(1<<m);S++){
for (int x=0;x<=m*2;x++){
if ((S&(1<<j-1))||(a[j]&(1ll<<i))){
add(f[i][j][S][x][0],f[i][j-1][S][x+1][1]);
add(f[i][j][S][x][1],f[i][j-1][S][x+1][0]);
}
int S_=S; if (a[j]&(1ll<<i)) S_|=(1ll<<j-1);
add(f[i][j][S_][x][0],f[i][j-1][S][x][0]);
add(f[i][j][S_][x][1],f[i][j-1][S][x][1]);
}
}
}
if (!i) break;
for (int S=0;S<(1<<m);S++)
for (int x=0;x<=m;x++)
add(f[i-1][0][S][x*2+((n>>i-1)&1)][0],f[i][m][S][x][0]);
}
for (int S=0;S<(1<<m);S++) sub(ansS,f[0][m][S][0][0]);
printf("%d",ansS);
return 0;
}
```