CF1119H Triple 题解

· · 题解

不妨将三元组推广至更一般的情况:给出 d_i(0 \le i \lt m)a_{i,j}(1 \le i \le n,0 \le j \lt m,0 \le a_{i,j} \lt 2^k),令集合幂级数 f_i=\sum\limits_{i=0}^{m-1} d_ix^{a_{i,j}},求所有 f_i 的异或卷积 g=\prod\limits_{i=1}^n f_i

F_i=\text{FWT}(f_i)G=\text{FWT}(g),那么有 G=\prod\limits_{i=1}^n F_i。考虑对于每一位 p 求出 [x^p]\prod\limits_{i=1}^n F_i=\prod\limits_{i=1}^n [x^p] F_i,这样即可求出 G,再做一个 IFWT 就可以得到 g 了。

由于 [x^p]F_i=\sum\limits_{j=0}^{m-1} (-1)^{\text{popcount}(p \land a_{i,j})} d_jd_j 前的系数只有 \pm 1,因此 [x^p]F_i 的取值只有 2^m 种,于是考虑求出 [x^p]\prod\limits_{i=1}^n F_i 中每种取值的出现次数,再快速幂得到 G_p

接下来,我们对于每一位 p,考虑如何求出 [x^p]\prod\limits_{i=1}^n F_i 中每种取值的出现次数。

c_s(0 \le s \lt 2^m) 表示第 p 位中 \sum\limits_{j=0}^{m-1} (-1)^{[2^j \subseteq s]} d_j 的出现次数,那么有 c_s=\sum\limits_{i=1}^n\prod\limits_{j=0}^{m-1} \left[ (-1)^{\text{popcount}(p \land a_{i,j})}=(-1)^{[2^j\subseteq s]} \right],不过这样计算的时间复杂度无法接受,我们需要找到 2^m 个线性无关的方程来求解 c。这启发我们对于每一个 T \subseteq \{0,1,\dots,m-1\},仅将 z_{i,T}=\bigoplus\limits_{j \in T} a_{i,j} 代入集合幂级数。令集合幂级数 h_T=\sum\limits_{i=1}^n x^{z_{i,T}}H_T=\text{FWT}(h_T)

考虑 c_s[x^p]H_T 的贡献。根据 FWT 的线性性可以得到,[x^p]H_T=\sum\limits_{i=1}^n (-1)^{\text{popcount}(p\land z_{i,T})}。又因为 \text{popcount}(i \land (j \oplus k)) \equiv \text{popcount}(i \land j)+\text{popcount}(i \land k) \pmod 2(这个性质可以拆位理解),所以 \sum\limits_{i=1}^n(-1)^{\text{popcount}(p\land z_{i,T})}=\sum\limits_{i=1}^n\prod\limits_{j\in T} (-1)^{\text{popcount}(p \land a_{i,j})}

于是我们可以把 T 看成一个二进制数 t,那么对于 j \in [0,m),只有当 2^j \subseteq s2^j \subseteq t 时,c_s 会对 [x^p] H_T 造成 -1 倍的贡献,因此 c_s 一共会对 [x^p] H_T 造成 (-1)^{\text{popcount}(s \land t)} 倍的贡献,得到:

[x^p]H_T=\sum_{s=0}^{2^m-1} (-1)^{\text{popcount}(s \land t)}c_s

容易发现这其实就是对 c 做 FWT 后 x^T 的系数。于是我们只需要对于每一个 T 求出 C_T=H_{T,p},做一个 IFWT 得到 c,快速幂计算出 [x^p]G,最后再 IFWT 得到 g 即可。

时间复杂度 \mathcal O(nm2^m+(m+k)2^{m+k})

const int N=1e5+5,K=1<<17,P=3,M=1<<3,mod=998244353,inv=(mod+1)/2;
int n,m,k,kk,d[P],a[N][P],z[M][N],c[N],h[M][K],g[K],w[M];
int ad(int a,int b){
    a+=b;
    if(a>=mod) a-=mod;
    return a;
}
void add(int &a,int b){
    a+=b;
    if(a>=mod) a-=mod;
}
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1) res=1ll*res*a%mod;
        a=1ll*a*a%mod;
        b>>=1; 
    }
    return res;
}
void FWT(int *f,int V,int c){
    for(int i=1;i<V;i<<=1){
        for(int j=0;j<V;j++){
            if(j&i) continue;
            int x=f[j],y=f[j|i];
            f[j]=ad(x,y),f[j|i]=ad(x,mod-y);
        }
    }
    if(c!=1) for(int i=0;i<V;i++) f[i]=1ll*f[i]*c%mod;
}
void solve(){
    cin>>n>>kk,m=3,k=17;
    for(int i=0;i<m;i++) cin>>d[i],d[i]%=mod;
    for(int i=0;i<M;i++){
        for(int j=0;j<m;j++){
            if(i&(1<<j)) add(w[i],mod-d[j]);
            else add(w[i],d[j]);
        }
    }
    for(int i=1;i<=n;i++) for(int j=0;j<m;j++) cin>>a[i][j];
    for(int t=0;t<M;t++){
        for(int i=1;i<=n;i++) for(int j=0;j<m;j++) if(t&(1<<j)) z[t][i]^=a[i][j];
        for(int i=1;i<=n;i++) h[t][z[t][i]]++;
        FWT(h[t],K,1);
    }
    for(int p=0;p<K;p++){
        for(int t=0;t<M;t++) c[t]=h[t][p];
        FWT(c,M,ksm(inv,m));
        g[p]=1;
        for(int i=0;i<M;i++) g[p]=1ll*g[p]*ksm(w[i],c[i])%mod;
    }
    FWT(g,K,ksm(inv,k));
    for(int i=0;i<(1<<kk);i++) cout<<g[i]<<' ';
}