题解 P3349 [ZJOI2016]小星星

· · 题解

本题解配合我的 容斥原理学习笔记 食用更佳。

给出一张图 G 和一棵树 T,点数都为 n,求有多少个长为 n 的排列 \{a_i\},满足若存在 \langle u,v\rangle \in T 则一定有 \langle a_u,a_v\rangle \in G

题意转化一下,就是给树上每一个点分配一个编号,并且相邻的点有编号的限制,要求最后编号无重复。

其实一开始不太容易想到要容斥,不妨先考虑暴力状压 DP。设 f(u,i,S) 表示 u 点编号为 i,且以 u 为根的子树的编号集合为 S,的方案数。转移就是对几个不相交的集合的合并。显然这样复杂度是爆炸的,但是已经没有优化的余地了。

这时候容斥的作用就显现了,我们可以先放宽限制来降低复杂度,然后容斥来得到最终答案。我们可以瞅准转移时“集合不相交”的限制,但是去掉了这个限制的话,编号会重复,编号集合的大小小于 n

这时候有容斥的味道了。我们发现要求的是编号集合大小等于 n,而现在只能得到编号集合大小小于等于 n 的方案数。于是构造若干个限制:枚举编号集合 S,钦定整棵树的编号集合为 S 的子集。\mathcal O(n^3) DP 一下可以得到 s(S) 表示整棵树编号集合为 S 的子集的方案数。

上套路,设大小为 n 的编号集合为 U,对于一个编号集合恰好为 S 的方案,正确的贡献应该为 [S=U],得到容斥系数计算式

\sum_{T\supseteq S}f(T)=[S=U]

识别为集合形式,子集反演可以得到容斥系数

\begin{aligned} f(S)&=\sum_{T\supseteq S}(-1)^{|T|-|S|}[T=U]\\ &=(-1)^{n-|S|} \end{aligned}

最后答案就是 \sum_{S\subseteq U}s(S)f(T)

于是我们得到了 \mathcal O(n^32^n) 的做法,实际上跑的飞快。

typedef long long LL;

const int N = 20;
const int M = (1 << 17) + 5;

int n, m;
LL ans;
int g[N][N], popcnt[M];
LL f[N][N];
vector<int> e[N], lst;

void dp(int u, int fa) {
    for (auto v: e[u]) {
        if (v == fa) continue;
        dp(v, u);
    }
    for (auto i: lst) {
        f[u][i] = 1;
        for (auto v: e[u]) {
            if (v == fa) continue;
            LL sum = 0;
            for (auto j: lst) {
                if (!g[i][j]) continue;
                sum += f[v][j];
            }
            f[u][i] *= sum;
        }
    }
}

inline void clear() {
    if (lst.empty()) return;
    for (int i = 1; i <= n; ++i)
        for (auto j: lst)
            f[i][j] = 0;
    lst.clear();
}

inline void main() {
    cin >> n >> m;
    for (int i = 1; i <= m; ++i) {
        int u, v;
        cin >> u >> v;
        g[u][v] = g[v][u] = 1;
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    int lim = 1 << n;
    for (int i = 0; i < lim; ++i) popcnt[i] = popcnt[i >> 1] + (i & 1);
    for (int s = 0; s < lim; ++s) {
        clear();
        for (int i = 1; i <= n; ++i)
            if (s & (1 << (i - 1))) lst.push_back(i);
        dp(1, 0);
        LL sum = 0;
        for (auto i: lst) sum += f[1][i];
        ans += sum * (((n - popcnt[s]) & 1) ? -1 : 1);
    }
    cout << ans << '\n';
}