CF1119H Triple

· · 题解

尝试解决更为一般化的问题:给定 w_0,w_2 \ldots w_{k-1}a_{i,j} (1 \leq i \leq n,0 \leq j < k),求 \prod_{i=1}^n\sum_{j=0}^{k-1} w_j x^{a_{i,j}}。 本题即 k = 3 的特例。

考虑将 黎明前的巧克力 一题的做法扩展到一般情况。

此时,一个多项式的 FWT 数组的每一项有 2^k 种可能的取值,即:\sum s_i w_i ,其中 s_i = \pm 1。我们想要对于每一项,求这 n 个 FWT 数组的对应项中,这 2^k 种取值分别出现了多少次。以下讨论均针对 FWT 数组的某一项。

c_0,c_1 \ldots c_{2^k - 1} 分别为这 2^k 种取值的出现次数。为了方便,对于 c_x 对应的系数数组 s,我们把 x 看作长为 k 的二进制数,若其第 i 位为 1s_i = -1,否则 s_i = 1。显然这样可以和 s 一一对应。

考虑对于 \{0,1, \ldots k-1\} 的一个子集 T,令 z_i = \bigoplus_{j \in T} a_{i,j},求 \sum_{i=1}^n x^{z_i} 的 FWT 数组,它的这一项是什么样子。

考虑 |T| = 1 时,当然是 \sum \pm c_i 这样的形式。具体地,如果 T 中唯一的元素为 j,那么对于二进制下第 j 位为 1ic_i 前面的系数是 -1,否则是 1。(相当于你只保留一项的系数做 FWT,看这一项的系数的影响)。

|T| > 1 时,那其实 x^{z_i} = \prod_{j \in T} x^{a_{i,j}},FWT 数组的这一项即为所有元素各自的 FWT 数组(上面一种情况)这一项的乘积。因此这时候 c_i 前面的系数依然是 \pm 1,具体地,可以发现,若把 T 也写成长为 k 的二进制数 tc_i 前面的系数实际上是 (-1)^{|t \operatorname{and} i|}。(只有都是 1 的位置会乘一个 -1

敏锐的读者此时应该已经意识到了,这正是 FWT 的系数。即,若求出 H_t 表示对于二进制表示为 t 的集合 T\sum_{i=1}^n x^{z_i} 的 FWT 数组的这一项,令人意外地, H 就是 c_0,c_1 \ldots c_{2^k - 1} 进行 FWT 之后得到的数组。现在我们要求 c,只需把 H 逆变换回去就行了。总复杂度 O(n2^k + (m + k)2^{m + k})

下面是我的代码实现。将 k 设为 23,略加修改输入格式即可直接通过上面提到的黎明前的巧克力和本题。

#include <bits/stdc++.h>

using ll = long long;
const int mod = 998244353;
const int inv2 = (mod + 1) / 2;
const int maxn = (1 << 20) + 12;

int n,m,k,w[maxn],z[maxn][12],F[maxn],G[maxn],W[maxn],sum[maxn],lg[maxn];

int qpow(int a,ll b) {
    if (b == 0) return 1;
    ll d = qpow(a,b>>1); d = d * d % mod;
    if (b & 1) d = d * a % mod;
    return d;
}

void DFT(int *a,int lim,int flag){
    for (int i = 1; i < lim; i <<= 1)
    for (int j = 0; j < lim; j += (i << 1))
    for (int k = 0; k < i; ++ k) {
        ll A0 = a[j+k], A1 = a[j+k+i];
        a[j+k] = (A0 + A1) % mod; a[j+k+i] = (A0 - A1 + mod) % mod;
        if (flag == -1) { 
            a[j+k] = (ll) a[j+k] * inv2 % mod;
            a[j+k+i] = (ll) a[j+k+i] * inv2 % mod;
        }
    }
}

int main() {
    scanf("%d%d%d",&n,&m,&k);
    int H[(1<<k)+10][(1<<m)+10];
    std::memset(H,0,sizeof(H));
    for (int i = 0; i < k; ++ i) scanf("%d",&w[i]);
    lg[0] = -1;
    for (int i = 0; i < (1 << k); ++ i) {
        if (i) lg[i] = lg[i/2] + 1;
        for (int j = 0; j < k; ++ j) {
            if (i & (1 << j)) W[i] = (W[i] - w[j] + mod) % mod;
            else W[i] = (W[i] + w[j]) % mod;
        }
    }
    for (int i = 1; i <= n; ++ i) {
        for (int j = 0; j < k; ++ j) 
            scanf("%d",&z[i][j]); 
        H[0][0] += 1;
        for (int mask = 1; mask < (1 << k); ++ mask) {
            int lowbit = mask & -mask;
            int b = lg[lowbit];
            sum[mask] = sum[mask - lowbit] ^ z[i][b];
            H[mask][sum[mask]] += 1;
        }
    }
    for (int mask = 0; mask < (1 << k); ++ mask)
        DFT(H[mask],(1<<m),1);
    for (int i = 0; i < (1 << m); ++ i) {
        for (int j = 0; j < (1 << k); ++ j) 
            G[j] = H[j][i];
        DFT(G,(1<<k),-1);
        F[i] = 1;
        for (int j = 0; j < (1 << k); ++ j) 
            F[i] = (ll) F[i] * qpow(W[j],G[j]) % mod;
    } DFT(F,(1<<m),-1);
    for (int i = 0; i < (1 << m); ++ i)
        printf("%d ",F[i]);
    return 0;
}