快速莫比乌斯/沃尔什变换 (FMT/FWT)

· · 题解

基本概念

定义

本质就是全集的各个子集到值域的映射。

形式化地来说就是,域 F 上的集合幂级数是 2^U\to F 的函数,对于每个 S\subseteq U,都有 f_s \in F

从多项式的角度理解,就是把之前的下标为某个数字改为为某个具体集合。

表示

我们用 cx^s 表示 f_s=c,于是集合幂级数就可以表示为 f=\sum\limits_{S\in U}f_sx^s

加法与乘法

加法直接对应系数相加即可,注意此时 fg 必须全集相同。

乘法 h=f\times g=(\sum\limits_{L\in 2^U}f_Lx^L)\times(\sum\limits_{R\in 2^U}g_Rx^R)。当集合幂级数的乘法对加法有分配律的时候有,

h=f\times g=\sum\limits_{L\in 2^U}\sum\limits_{R\in 2^U}f_Lx^L\times g_Rx^R

我们设 f_Lx^L\times g_Rx^R=(f_L\times g_R)x^{L*R},其中 \times 就是 F 域乘法,* 就是集合域运算,且满足结合律和交换律

根据 * 的不同有不同算法,以下一一介绍。

集合并卷积 / OR 卷积

就是把上述 * 定义为 \bigcup,然后我们就可以得到 h_s=\sum\limits_{L\in 2^U}\sum\limits_{R\in 2^U}[L\cup R=S]f_Lg_R 暴力计算时间复杂度 O(4^n)

莫比乌斯变换(FMT)

定义快速莫比乌斯变换,f'={\rm FMT}(f),为 f'_i=\sum\limits_{j\subset i}f_j,也就是子集求和。

定义快速莫比乌斯反演,f'={\rm FMI}(f),为 f'_i=\sum\limits_{j|i=i}(-1)^{|i|-|j|}f_j

其中 FMT 和 FMI 互为逆变换。对于 OR 卷积,我们只需要把 fg 做一下 FMT,然后 h_i\gets f_i\times g_i,最后对于 h 做一个 FMI 就解决了。

莫比乌斯变换可以通过高维前缀和来实现,所以我们对于 FMT 就是做一个系数为 1 的高维前缀和,对于 FMI 就是做一个系数为 -1 的高维前缀和。

莫比乌斯变换的本质是容斥原理,原先 L\cup R=S 这个限制条件使得我们难以拆开 fg,因为二者互相关联。考虑使用容斥原理,我们钦定 S 中为 0 的位置,在 L\cup R 的结果中也为 0。于是我们就把限制宽松到了两个子集,只需要满足 L\subset SR\subset S 即可,这就可以通过做一次高维前缀和(FMT),然后对应位置相乘就行了。你钦定完之后,还需要设计容斥系数(FMI),系数是 (-1)^{(n-|L|)-(n-|S|)}=(-1)^{|S|-|L|}

沃尔什变换(FWT)

对于卷积 h\gets f\times g,我们按照二进制最高位的奇偶性分为 f_0,f_1,g_0,g_1,h_0,h_1

h_0=f_0\times g_0 h_1=f_1\times g_0+f_0\times g_1+f_1\times g_1=(f_0+f_1)(g_0+g_1)-f_0g_0

发现以上都可以通过 (f_0+f_1)\times (g_0+g_1)f_0\times g_0 的组合求解出来。

进行如下变换,

(f_0,f_1,g_0,g_1)\to (f_0,f_0+f_1,g_0,g_0+g_1)

变换之后,递归求 I_0=f'_0\times g'_0I_1=f'_1\times g'_1 即可。递归到底层的时候就是对应位置直接乘就行了。回溯的时候进行如下变换,

(h_0,h_1)\gets (I_0,I_1-I_0)

很容易写出一个分治的形式,但是常数很大,我们将其改成非递归形式。

容易发现其实往下递归的时候是 f_0 不变,f_1\gets f_0+f_1,往上回溯的时候也是 f_0 不变,f_1\gets f_1-f_0。我们可以将这两部分的代码,传入一个系数 t=1 或者 t=mod-1

为了模拟递归,我们先要枚举递归层数,也就是正在处理的这一位 2^k,然后分开枚举 >2^k<2^k 的值就行了。拼在一起就是当前正在做的数字。

for(int k=1;k<full;k<<=1)
    for(int i=0;i<full;i+=(k<<1))
        for(int j=0;j<k;j++) 
            add(f[i|j|k],1ll*f[i|j]*t%mod);

集合交卷积 / AND 卷积

就是把上述 * 定义为 \bigcap,然后我们就可以得到 h_s=\sum\limits_{L\in 2^U}\sum\limits_{R\in 2^U}[L\cap R=S]f_Lg_R

不管是理解还是推导方面,都和 OR 卷积是一样的,就不过多赘述了。

莫比乌斯变换(FMT)

把 OR 卷积中系数为 +1-1 的高维前缀和,改为系数为 +1-1 的高维后缀和即可。

容斥的时候钦定 L\cap R 中为 1 的位置是 S 即可。

沃尔什变换(FWT)

也是几乎和 OR 卷积一模一样。

先进行 FWT,

(f_0,f_1,g_0,g_1)\to (f_0+f_1,f_1,g_0+g_1,g_1)

再进行 IFWT,

(h_0,h_1)\gets (I_0-I_1,I_1)

集合对称差卷积 / XOR 卷积

就是把上述 * 定义为 \bigoplus,然后我们就可以得到

h_s=\sum\limits_{L\in 2^U}\sum\limits_{R\in 2^U}[L\bigoplus R=S]f_Lg_R

暴力计算时间复杂度 O(4^n)。关于 XOR 卷积就没有 FMT 了,只能进行 FWT。

沃尔什变换(FWT)

定义沃尔什变换 f'={\rm FWT(f)},为 f'_s=\sum\limits_{T\in 2^U}f_T(-1)^{\lvert S\cap T\rvert}

定义沃尔什逆变换,f'={\rm IFWT(f)},为 f'_s=\dfrac{1}{2^n}\sum\limits_{T\in 2^U}f_T(-1)^{\lvert S\cap T\rvert}

对于卷积 h\gets f\times g,我们按照二进制最高位的奇偶性分为 f_0,f_1,g_0,g_1,h_0,h_1

h_0=f_1\times g_1+f_0\times g_0 h_1=f_1\times g_0+f_0\times g_1

发现以上都可以通过 (f_0+f_1)\times (g_0+g_1)(f_0-f_1)\times (g_0-g_1) 组合出来。

进行如下变换,

(f_0,f_1,g_0,g_1)\to (f_0+f_1,f_0-f_1,g_0+g_1,g_0-g_1)

变换之后,递归求 I_0=f'_0\times g'_0I_1=f'_1\times g'_1 即可。递归到底层的时候就是对应位置直接乘就行了。回溯的时候进行如下变换,

(h_0,h_1)\gets (\dfrac{I_0+I_1}{2},\dfrac{I_0-I_1}{2})

将上面这个递归形式改成非递归形式就行了。\dfrac{1}{2} 提取出来放到最后乘,每一位都要乘以 \dfrac{1}{2},所以总计是乘以 \dfrac{1}{2^n}

void FWT(int *f,bool type){
    for(int k=1;k<full;k<<=1)
        for(int i=0;i<full;i+=(k<<1))
            for(int j=0;j<k;j++){
                int x=f[i|j],y=f[i|j|k];
                f[i|j]=(x+y)%mod; f[i|j|k]=(x-y+mod)%mod;
            }
    if(!type) return ;
    for(int i=0;i<full;i++) f[i]=1ll*f[i]*invp%mod;
}

线性代数角度的 FWT

我们设

f'_s=\sum\limits_{T\in 2^U}c_{s,t}f_T

由于变换后对应位置相乘需要满足 * 运算,所以 c_{i,j*k}=c_{i,j}c_{i,k}。还是和之前拆位思想相同。我们对于每一位独立进行上述变化。此时只需要考虑 s,t 的某一个二进制位,所以下标取值只有 0/1,故这是一个 2\times 2 的矩阵。

按位或矩阵

\begin{bmatrix}1&0\\ 1&1 \end{bmatrix}

其逆矩阵为

\begin{bmatrix}1&0\\ -1&1\end{bmatrix}

按位与矩阵

\begin{bmatrix}1&1\\0&1\end{bmatrix}

其逆矩阵为

\begin{bmatrix}1&-1\\0&1\end{bmatrix}

按位异或矩阵

\begin{bmatrix}1&1\\0&-1\end{bmatrix}

其逆矩阵为(为了减少常数,可以把 \frac{1}{2} 提取出来,最后再乘)

\begin{bmatrix}0.5&0.5\\0.5&-0.5\end{bmatrix}

其实很好理解,根据之前的推导,OR 卷积和 AND 卷积的变换应该是要求一个是对于子集求和,一个是对于超集求和。或矩阵中的 10,01,11\to 100\to 0 恰好对应子集求和的形式。与矩阵同理。异或矩阵也符合我们推到的 FWT 的系数。

从这个角度,我们可以看出 FWT 是一种线性变换。

扩展 三进制 MEX 卷积

这一部分旨在启发构造一些变式位运算卷积的方法。

定义三进制 \operatorname{mex} 运算,先将两个数字 a,b 进行三进制拆分,形如 a=\sum a_i3^i\operatorname{mex}(a,b)=\sum\limits_{i=0}^{k-1}\operatorname{mex}(a_i,b_i)。也就是说把每个数都拆成三进制表示,然后对于每个三进制位进行 \operatorname{mex}。 定义三进制 \operatorname{mex} 卷积为

h_s=\sum\limits_{\operatorname{mex}(L,R)=S}f_L\times g_S

给定长度为 3^{n}fg,求解 hn\le 10

还是按照 FWT 那套方法,对于当前的最高位的值进行分类。

c_0=a_1b_1+a_1b_2+a_2b_1+a_2b_2=(a_1+a_2)(b_1+b_2) c_1=a_0b_0+a_0b_2+a_2b_0=(a_0+a_2)(b_0+b_2)-a_2b_2 c_2=(a_0+a_1+a_2)(b_0+b_1+b_2)-c_0-c_1

发现我们需要构造四种乘法,

(a_0,a_1,a_2)\to (a_1+a_2,a_0+a_2,a_2+a_0+a_1+a_2)

这样子递归是 3^n\to 4^n,因为三项变四项。

逆变换,令上述分别为 A,B,C,D(A,B,C,D)\to (A,B-C,D-A-B-C)

可以递归计算,时间复杂度是 O(n4^n)

Code

代码中的 OR 卷积和 AND 卷积都是用 FMT 实现,XOR 卷积用 FWT 实现。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
const int maxn=19;
const int maxs=(1<<19);
int n,full,a[maxs],b[maxs],invp;
void add(int &x,int y){ x=x+y>=mod?x+y-mod:x+y; }
void sub(int &x,int y){ x=x-y<0?x-y+mod:x-y; }
void FMT(int* f,int t){
    for(int i=1;i<full;i<<=1)
        for(int j=0;j<full;j++)
            if(i&j) add(f[j],1ll*t*f[i^j]%mod);
}
void FMT2(int* f,int t){
    for(int i=1;i<full;i<<=1)
        for(int j=0;j<full;j++)
            if(!(i&j)) add(f[j],1ll*t*f[i^j]%mod);
}
void FWT(int *f,bool type){
    for(int k=1;k<full;k<<=1)
        for(int i=0;i<full;i+=(k<<1))
            for(int j=0;j<k;j++){
                int x=f[i|j],y=f[i|j|k];
                f[i|j]=(x+y)%mod; f[i|j|k]=(x-y+mod)%mod;
            }
    if(!type) return ;
    for(int i=0;i<full;i++) f[i]=1ll*f[i]*invp%mod;
}
void print(int *H){
    for(int i=0;i<full;i++) cout<<H[i]<<" ";
    cout<<endl;
}
int A[maxs],B[maxs],C[maxs];
void or_conv(){
    memcpy(A,a,sizeof(a)); memcpy(B,b,sizeof(b));
    FMT(A,1); FMT(B,1);
    for(int i=0;i<full;i++) C[i]=1ll*A[i]*B[i]%mod;
    FMT(C,mod-1); print(C);
}
void and_conv(){
    memcpy(A,a,sizeof(a)); memcpy(B,b,sizeof(b));
    FMT2(A,1); FMT2(B,1);
    for(int i=0;i<full;i++) C[i]=1ll*A[i]*B[i]%mod;
    FMT2(C,mod-1); print(C);
}
void xor_conv(){
    memcpy(A,a,sizeof(a)); memcpy(B,b,sizeof(b));
    FWT(A,0); FWT(B,0);
    for(int i=0;i<full;i++) C[i]=1ll*A[i]*B[i]%mod;
    invp=1; for(int i=1;i<=n;i++) invp=1ll*invp*((mod+1)>>1)%mod;
    FWT(C,1); print(C);
}
int main(){
    cin>>n; full=(1<<n);
    for(int i=0;i<full;i++) cin>>a[i];
    for(int i=0;i<full;i++) cin>>b[i];
    or_conv(); and_conv(); xor_conv();
    return 0;
}