哈集幂!

· · 题解

挺神妙的一个题吧。自己做的时候完全没做明白,看题解也看到晕晕乎乎的。今天重新从哈集幂的视角理解了一下,才发现确实挺妙(而且典)。

面对这个题首先做一些小转化。考虑一条边,它在哪几条路径上?我们用一个掩码(k 位二进制数)表示这个集合。第 i 位为 1 代表这条边在 a_ib_i 间的简单路径上。

我们考虑给树赋一个根,然后把每条边和它连接的两个点中深度较大的对应起来。一条路径经过这条边,当且仅当一个端点在这个点的子树中,而另一个端点不在子树中。这样,我们对每个点维护一个掩码,第 i 位代表第 i 条路径在该点子树中的端点的数目的奇偶性。不难发现这个掩码其实和这个点对应的边的掩码相等。

容易预处理这个掩码。具体的,考虑对第 i 条路径的一个端点 a\text{mask}_a = \text{mask}_a \text{ xor } 2^i。然后我们做一次 dfs,对每个点,先处理它的所有子节点,然后它的掩码异或上它所有子节点的掩码的异或和。

然后,我们就注意到,挑选一些边涂色,这时被满足(即有至少一条边被涂色)的简单路径的集合就是这些边对应的掩码(集合)的并集(或者按位或)。

至此,我们就完成了所有平凡的预处理工作。问题被转化为:

给定若干集合,求最少需要选取多少集合才能使得它们的并集为全集。

这个问题是本题的核心,也是一个非常经典的 NP-hard 问题。

我们考虑将最优化问题转化为判定问题:从小到大枚举需要多少集合,再判定这是否可行。

注意到答案最大等于全集大小。因为每个集合至少贡献一位,否则直接去除它更好。

进一步转化为计数问题:选出 x 个集合,有多少种方案使得它们的并集为全集?

原题在这里要求选出的集合互不相同,但选出重复集合并不影响。

考虑哈集幂。我们构造一个长度为 2^k 的序列 a_i(数组),其中第 s 个位置是选出 i 个集合,其并集是掩码为 s 的集合的方案数。

显然有 a_i = (a_1)^i,这里的乘法是并集卷积,即 (\text{or}, \times) 卷积。意思是,a\times b 的第 k 位是所有满足 i \text{ or } j = ka_i \times b_j 的和。

然后套用 SOS dp 和高维前缀和/FWT 模板即可。

放一个 AC 代码,里面涉及到一些实现细节。

#include <bits/stdc++.h>

using namespace std;
int t;
array<uint32_t, 100005> mask;
array<vector<int>, 100005> e;

void dfs(int u, int fa) {
    for (int v: e[u]) {
        if (v != fa) {
            dfs(v, u);
            mask[u] ^= mask[v];
        }
    }
}

template<typename T>
constexpr uint64_t mod(T res) {
    constexpr uint64_t mod_p = (1ull << 61) - 1;
    res = (res >> 61) + (res & mod_p);
    res = (res >> 61) + (res & mod_p);

    if (res == mod_p) {
        return 0;
    }
    return res;
}

uint64_t quick_pow(uint64_t a, uint64_t b) {
    uint64_t res = 1;
    while (b) {
        if (b & 1) {
            res = mod((unsigned __int128) res * a);
        }

        a = mod((unsigned __int128) a * a);
        b >>= 1;
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> t;
    for (int _ = 0; _ < t; ++_) {
        int n;
        cin >> n;

        for (int i = 1; i <= n; ++i) {
            e[i].clear();
            mask[i] = 0;
        }

        for (int i = 1; i < n; ++i) {
            int u, v;
            cin >> u >> v;
            e[u].push_back(v);
            e[v].push_back(u);
        }

        int k;
        cin >> k;
        for (int i = 0; i < k; ++i) {
            int x, y;
            cin >> x >> y;
            mask[x] ^= (1 << i);
            mask[y] ^= (1 << i);
        }

        dfs(1, 0);

        vector<uint32_t> cnt(1 << k);
        for (int i = 1; i <= n; ++i) {
            cnt[mask[i]]++;
        }

        for (int i = 0; i < k; ++i) {
            for (uint32_t j = 0; j < (1 << k); ++j) {
                if ((j >> i) & 1) {
                    cnt[j] += cnt[j ^ (1 << i)];
                }
            }
        }

        for (int ans = 1; ans <= k; ++ans) {
            uint64_t term = 0;
            for (uint32_t i = 0; i < (1 << k); ++i) {
                const uint64_t tmp = quick_pow(cnt[i], ans);
                term = mod((1ll << 61) - 1 + term - ((k - popcount(i)) & 1 ? tmp : -tmp));
            }

            if (term) {
                cout << ans << "\n";
                break;
            }
        }
    }
    return 0;
}