P11458 Solution
下文先给出一个传统的 FWT 视角做法,再给出一个使用转置原理理解下的做法。
令
首先,可以想到一个
- 不妨对每个
0 \leq s \leq k ,求出有多少个j 满足P(a_i\mid a_j) = s 。我们以类似的手法求出\sum P(a_j)[P(a_i\mid a_j) = s] 。 - 令
b_k = \sum_j [a_j = k] ,对b 做高维后缀和,保留\operatorname{popcount} 恰好为s 的元素后,再做高维前缀和,在a_i 处统计对应的答案。 - 这个做法会算重,算重的原因在于我们可能实际上枚举到了
a_i\mid a_j 的超集,但对每种a_i 再做一次二项式反演即可正确处理所有贡献。
为了优化复杂度至
上一步使用了一个经典的容斥技巧:
对其交换求和顺序即得。
此时,位运算卷积技巧已经足以帮助我们解决这个问题:使用高维后缀和,对每个
int main() {
int n = read<int>(), k = read<int>(), z = (1 << k);
for (int i = 0; i < z; ++i)
pop[i] = __builtin_popcount(i);
for (int j = 1; j <= k; ++j)
iv[j] = modint(j).inverse();
for (int i = 1; i <= n; ++i) {
x[i] = read<int>();
c[x[i]] += 1, d[x[i]] += pop[x[i]];
}
for (int j = 1; j < z; ++j)
w[j] = modint(pop[j] & 1 ? -1 : 1) * iv[pop[j]];
for (int i = 0; i < k; ++i)
for (int j = 0; j < z; ++j) if ((j >> i) & 1)
w[j ^ (1 << i)] += w[j];
for (int j = 0; j < z; ++j)
w[j] *= modint(pop[j] & 1 ? -1 : 1);
auto process = [&](modint c[], modint w[]) -> void {
for (int i = 0; i < k; ++i)
for (int j = 0; j < z; ++j) if ((j >> i) & 1)
c[j] += c[j ^ (1 << i)];
for (int j = 0; j < z; ++j)
c[j] *= w[j];
for (int i = 0; i < k; ++i)
for (int j = 0; j < z; ++j) if ((j >> i) & 1)
c[j ^ (1 << i)] += c[j];
};
process(c, w), process(d, w);
for (int i = 1; i <= n; ++i)
print<int>((modint(pop[x[i]]) * c[x[i]] + d[x[i]] - n).get(), '\n');
return 0;
}
考虑我们所求为一向量
这种理解的本质是:在一个线性问题里,如果计算输出向量的与任意向量的内积结果是可行的,则可以相同的时间复杂度解决原问题。
int main() {
int n = read<int>(), k = read<int>(), z = (1 << k);
for (int j = 0; j < z; ++j)
pop[j] = __builtin_popcount(j), inv[j] = pop[j].inverse();
for (int i = 1; i <= n; ++i)
a[i] = read<int>(), buc[a[i]] += 1, bucp[a[i]] += pop[a[i]];
auto conv = [&](modint res[], modint a[], modint b[]) -> void {
for (int j = 0; j < k; ++j) for (int i = 0; i < z; ++i)
if ((i >> j) & 1) a[i] += a[i ^ (1 << j)];
for (int i = 0; i < z; ++i) bb[i] = b[i];
for (int j = 0; j < k; ++j) for (int i = 0; i < z; ++i)
if (!((i >> j) & 1)) bb[i] -= bb[i ^ (1 << j)];
for (int i = 0; i < z; ++i)
res[i] = a[i] * bb[i];
for (int j = 0; j < k; ++j) for (int i = 0; i < z; ++i)
if (!((i >> j) & 1)) res[i] += res[i ^ (1 << j)];
};
conv(c, buc, inv), conv(d, bucp, inv);
for (int i = 1; i <= n; ++i)
print<int>((pop[a[i]] * c[a[i]] + d[a[i]] - n).get(), '\n');
return 0;
}