题解:P7620 CF1431J Zero-XOR Array

· · 题解

前情提要:LOJ6878 生不逢时。

Solution

我们直接将问题强化为计算

\sum_{\substack{a_1,a_2,\cdots,a_n\\\forall 1\le i\le n,l_i\le a_i\le r_i}}\Big[\bigoplus_{i=1}^n a_i=0\Big]

套路将 [l_i,r_i] 拆成若干个形如 [p2^k,(p+1)2^k) 的区间,缝一个 meet in the middle 即可做到 O\big((2m)^{\frac n2}\big)

考虑优化。先考虑 l_i=0。深挖 x 的拆分的性质。记 f_k(x) 表示 x 在二进制下只保留第 k 位及以上构成的数,g_k(x)=[f_k(x),f_k(x)+2^k),则不难发现每个区间的形式即为 g_k(x\oplus 2^k)。考虑将 x 的拆分与 y 的拆分合并,设我们要合并 g_i(x\oplus 2^i)g_j(y\oplus 2^j),其中 g_i(x\oplus 2^i) 中所有位置的权值均为 pg_j(y\oplus 2^j) 中所有位置的权值均为 q。不难发现这两个区间做异或卷积的结果是 g_{\max(i,j)}\big(f_i(x)\oplus f_j(y)\oplus 2^i\oplus 2^j\big),其中每个元素的权值均为 p\times q\times 2^{\min(i,j)}。发现这好像不是很封闭啊,但是这不重要,我们增加参数 t\in\{0,1\} 并维护每个 g_i(x\oplus t\times 2^i) 的权值即可,这样就封闭了。这部分可以做到 O(nm)

然后将 [l_i,r_i] 拆成 [0,r_i+1)-[0,l_i),直接枚举是 O(2^nm) 的,跟刚刚一样做一个 meet in the middle 即可做到 O(2^{\frac n2}m)

Code

bool Mst;
#include<bits/stdc++.h>
using namespace std;
using ui=unsigned int;
using ll=long long;
using ull=unsigned long long;
using i128=__int128;
using u128=__uint128_t;
using pii=pair<int,int>;
#define fi first
#define se second
constexpr int N=20,T=9,NN=(61<<T)+5,mod=998244353;
inline int add(int x,int y){return (x+=y)>=mod&&(x-=mod),x;}
inline int Add(int &x,int y){return x=add(x,y);}
inline int sub(int x,int y){return (x-=y)<0&&(x+=mod),x;}
inline int Sub(int &x,int y){return x=sub(x,y);}
inline int qpow(int a,int b){
    int res=1;
    for(;b;b>>=1,a=(ll)a*a%mod)
        if(b&1)res=(ll)res*a%mod;
    return res;
}
int n,m,B;ll xs,a[N];
struct node{
    ll v;int f[61][2];
    node(){v=0,memset(f,0,sizeof f);}
    inline int* operator [](int i){return f[i];}
    inline const int* operator [](int i)const{return f[i];}
};
inline ll calc(ll x,int k){return x>>k<<k;}
inline node operator *(const node &a,const node &b){
    node c;c.v=a.v^b.v;
    int sa=0,sb=0,pw=1;
    for(int i=0;i<61;i++){
        Add(c[i][0],((ll)a[i][0]*b[i][0]+(ll)a[i][1]*b[i][1])%mod*pw%mod);
        Add(c[i][1],((ll)a[i][0]*b[i][1]+(ll)a[i][1]*b[i][0])%mod*pw%mod);
        Add(c[i][0],((ll)a[i][0]*sb+(ll)b[i][0]*sa)%mod);
        Add(c[i][1],((ll)a[i][1]*sb+(ll)b[i][1]*sa)%mod);
        Add(sa,(ll)(a[i][0]+a[i][1])*pw%mod);
        Add(sb,(ll)(b[i][0]+b[i][1])*pw%mod);
        Add(pw,pw);
    }
    return c;
}
node L[N],R[N];
inline void work(node &f,ll x){
    f.v=x;
    for(int i=0;i<=m;i++)
        if(x>>i&1)
            Add(f[i][1],1);
}
int ans;
void dfs(int x,node cur,int sgn){
    if(x==n+1){
        ll o=cur.v^xs;
        if(sgn==1){
            for(int i=0;i<61;i++){
                if(!calc(o,i))
                    Add(ans,cur[i][0]);
                if(!(calc(o,i)^1ll<<i))
                    Add(ans,cur[i][1]);
            }
        }
        else{
            for(int i=0;i<61;i++){
                if(!calc(o,i))
                    Sub(ans,cur[i][0]);
                if(!(calc(o,i)^1ll<<i))
                    Sub(ans,cur[i][1]);
            }
        }
        return;
    }
    dfs(x+1,cur*L[x],-sgn);
    dfs(x+1,cur*R[x],sgn);
}
int idx=1;
struct trie{int t[2],w[2],s[2];};
trie tr[NN];
void dfs0(int x,node cur,int sgn){
    if(x==T+1){
        if(sgn==1){
            int p=1,c,sum=0;ll o=cur.v^xs;
            for(int i=0;i<61;i++)
                Add(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
            for(int i=60;i>=0;i--){
                c=o>>i&1;
                Add(tr[p].w[0],cur[i][c]);
                Add(tr[p].w[1],cur[i][!c]);
                Add(tr[p].s[c],Sub(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod));
                if(!tr[p].t[c])tr[p].t[c]=++idx;
                p=tr[p].t[c];
            }
        }
        else{
            int p=1,c,sum=0;ll o=cur.v^xs;
            for(int i=0;i<61;i++)
                Add(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
            for(int i=60;i>=0;i--){
                c=o>>i&1;
                Sub(tr[p].w[0],cur[i][c]);
                Sub(tr[p].w[1],cur[i][!c]);
                Sub(tr[p].s[c],Sub(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod));
                if(!tr[p].t[c])tr[p].t[c]=++idx;
                p=tr[p].t[c];
            }
        }
        return;
    }
    dfs0(x+1,cur*L[x],-sgn);
    dfs0(x+1,cur*R[x],sgn);
}
void dfs1(int x,node cur,int sgn){
    if(x==n+1){
        if(sgn==1){
            int p=1,c,sum=0;ll o=cur.v;
            for(int i=0;i<61;i++)
                Add(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
            for(int i=60;i>=0;i--){
                c=o>>i&1;
                Sub(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
                Add(ans,((ll)cur[i][c]*tr[p].s[0]+(ll)cur[i][!c]*tr[p].s[1]+(ll)sum*tr[p].w[c])%mod);
                Add(ans,((ll)cur[i][c]*tr[p].w[0]+(ll)cur[i][!c]*tr[p].w[1])%mod*((1ll<<i)%mod)%mod);
                if(!tr[p].t[c])break;
                p=tr[p].t[c];
            }
        }
        else{
            int p=1,c,sum=0;ll o=cur.v;
            for(int i=0;i<61;i++)
                Add(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
            for(int i=60;i>=0;i--){
                c=o>>i&1;
                Sub(sum,(cur[i][0]+cur[i][1])*((1ll<<i)%mod)%mod);
                Sub(ans,((ll)cur[i][c]*tr[p].s[0]+(ll)cur[i][!c]*tr[p].s[1]+(ll)sum*tr[p].w[c])%mod);
                Sub(ans,((ll)cur[i][c]*tr[p].w[0]+(ll)cur[i][!c]*tr[p].w[1])%mod*((1ll<<i)%mod)%mod);
                if(!tr[p].t[c])break;
                p=tr[p].t[c];
            }
        }
        return;
    }
    dfs1(x+1,cur*L[x],-sgn);
    dfs1(x+1,cur*R[x],sgn);
}
bool Med;
int main(){
    cerr<<abs(&Mst-&Med)/1048576.0<<endl;
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>m,n--,B=m>>1;
    for(int i=0;i<=n;i++)cin>>a[i],xs^=a[i];
    if(n==0){
        cout<<(xs==0)<<'\n';
        return 0;
    }
    if(m==1){
        cout<<1<<'\n';
        return 0;
    }
    for(int i=1;i<=n;i++)
        work(L[i],a[i-1]),work(R[i],a[i]+1);
    node one;one[0][0]=1;
    if(n<=T)
        dfs(1,one,1);
    else
        dfs0(1,one,1),dfs1(T+1,one,1);
    cout<<ans<<'\n';
    return 0;
}