xht37 的洛谷博客

xht37's blog:https://www.xht37.com/

题解 P4717 【【模板】快速沃尔什变换】

请在博客中查看

核心思想

记对序列 $a$ 进行快速沃尔什变换后的序列为 $fwt[a]$。

已知序列 $a,b$,求一个新序列 $c = a \cdot b$,直接计算是 $\mathcal O(n^2)$ 的。

若 $a \to fwt[a]$ 和 $b \to fwt[b]$ 是 $\mathcal O(n \log n)$ 的,而 $fwt[c] = fwt[a] \cdot fwt[b]$ 是 $\mathcal O(n)$ 的,同时 $fwt[c] \to c$ 也是 $\mathcal O(n \log n)$ 的。

那么我们可以利用上述过程 $\mathcal O(n \log n)$ 求出 $c$。

OI 中的运用

在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。$$ c_{i}=\sum_{i=j \oplus k} a_{j} b_{k} $$ 其中 $\oplus$ 是二元位运算中的一种。

要求

$$ c_{i}=\sum_{i=j | k} a_{j} b_{k} $$

显然有 $j|i = i, k|i=i \to (j|k)|i = i$。

构造 $fwt[a]_i = \sum_{j|i=i} a_j$。

则有

$$ \begin{aligned} fwt[a] \times fwt[b] &= \left(\sum_{j|i=i} a_j\right)\left(\sum_{k|i=i} b_k\right) \\\\ &= \sum_{j|i=i} \sum_{k|i=i} a_jb_k \\\\ &= \sum_{(j|k)|i = i} a_jb_k \\\\ &= fwt[c] \end{aligned} $$

$a \to fwt[a]$

要求$$ fwt[a]_i = \sum_{j|i=i} a_j $$ 令 $a_0$ 表示 $a$ 中下标最高位为 $0$ 的那部分序列,$a_1$ 表示 $a$ 中下标最高位为 $1$ 的那部分序列。

则有$$ fwt[a] = \text{merge}(fwt[a_0], fwt[a_0] + fwt[a_1]) $$ 其中 $\text{merge}$ 表示「拼接」,$+$ 表示对应位置相加。

于是可以分治。

inline void OR(modint *f) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j];
}

$fwt[a] \to a$

由$$ fwt[a] = \text{merge}(fwt[a_0], fwt[a_0] + fwt[a_1]) $$ 可得$$ a = \text{merge}(a_0, a_1 - a_0) $$

inline void IOR(modint *f) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] -= f[i+j];
}

显然两份代码可以合并。

inline void OR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j] * x;
}

同理或。

inline void AND(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k] * x;
}

异或

定义 $x\otimes y=\text{popcount}(x \& y) \bmod 2$,其中 $\text{popcount}$ 表示「二进制下 $1$ 的个数」。

满足 $(i \otimes j) \operatorname{xor} (i \otimes k) = i \otimes (j \operatorname{xor} k)$。

构造 $fwt[a]_i = \sum_{i\otimes j = 0} a_j - \sum_{i\otimes j = 1} a_j$。

则有

$$ \begin{aligned} fwt[a] \times fwt[b] &= \left(\sum_{i\otimes j = 0} a_j - \sum_{i\otimes j = 1} a_j\right)\left(\sum_{i\otimes k = 0} b_k - \sum_{i\otimes k = 1} b_k\right) \\ &=\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)-\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=1}b_k\right)-\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)+\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=1}b_k\right) \\ &=\sum_{i\otimes(j \operatorname{xor} k)=0}a_jb_k-\sum_{i\otimes(j\operatorname{xor} k)=1}a_jb_k \\ &= fwt[c] \end{aligned} $$

因此

$$ \begin{aligned} fwt[a] &= \text{merge}(fwt[a_0] + fwt[a_1], fwt[a_0] - fwt[a_1]) \\\\ a &= \text{merge}(\frac{a_0 + a_1}2, \frac{a_0 - a_1}2) \end{aligned} $$

inline void XOR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k],
                f[i+j+k] = f[i+j] - f[i+j+k] - f[i+j+k],
                f[i+j] *= x, f[i+j+k] *= x;
}

【模板】P4717 【模板】快速沃尔什变换

const int N = 1 << 17 | 1;
int n, m;
modint A[N], B[N], a[N], b[N];

inline void in() {
    for (int i = 0; i < n; i++) a[i] = A[i], b[i] = B[i];
}

inline void get() {
    for (int i = 0; i < n; i++) a[i] *= b[i];
}

inline void out() {
    for (int i = 0; i < n; i++) print(a[i], " \n"[i==n-1]);
}

inline void OR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j+k] += f[i+j] * x;
}

inline void AND(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k] * x;
}

inline void XOR(modint *f, modint x = 1) {
    for (int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for (int i = 0; i < n; i += o)
            for (int j = 0; j < k; j++)
                f[i+j] += f[i+j+k],
                f[i+j+k] = f[i+j] - f[i+j+k] - f[i+j+k],
                f[i+j] *= x, f[i+j+k] *= x;
}

int main() {
    rd(m), n = 1 << m;
    for (int i = 0; i < n; i++) rd(A[i]);
    for (int i = 0; i < n; i++) rd(B[i]);
    in(), OR(a), OR(b), get(), OR(a, P - 1), out();
    in(), AND(a), AND(b), get(), AND(a, P - 1), out();
    in(), XOR(a), XOR(b), get(), XOR(a, (modint)1 / 2), out();
    return 0;
}

2019-12-11 03:10:04 in 题解