题解:P8315 [COCI 2021/2022 #4] Šarenlist

· · 题解

我们注意到钦定一条路径不合法是容易的,只须要求整条路径上颜色相同即可,共 k 种情况。故我们考虑容斥,每次钦定 S 中的路径不合法,答案即为

\sum_{S} (-1)^{|S|}ans_S

考虑如何计算 ans_S。首先,有一些边不在钦定不合法的路径中,可以在 k 种颜色中任意选择。随后,我们发现:如果两条路径相交至少一条边,那么这两条路径一定颜色相同。具体地,如果共有 x 条边在钦定的路径中,路径共形成了 y 个连通块,那就有 ans_S = k^{n - 1 - x} \times k^{y}

我们用 bitset 记录下每条路径上的边,将 S 中的 bitset 并起来的大小就是 x。我们再维护 con_{x, y} 表示 x 路径与 y 路径是否有交,在枚举路径 i 加入 S 时,枚举已经加入的 j,如果 con_{i, j} = 1 就在并查集上将 ij 合并,最终并查集的集合数量就是 y

时间复杂度 \mathcal{O}(m^2n + m2^m(\frac{n}{\omega} + m\alpha(m)))

Code:

#include<bits/stdc++.h>
#define mem(a, v) memset(a, v, sizeof(a))

using namespace std;

const int maxn = 60 + 10, maxm = 15 + 10, mod = 1e9 + 7;

struct{
    int v, nex;
} edge[maxn << 1];

int n, m, k, res = 0;
int dep[maxn], fat[maxn];
bool con[maxm][maxm];
bitset<maxn> pth[maxm], now;
int head[maxn], top = 0;

inline void add(int u, int v, bool o = true){
    edge[++top].v = v, edge[top].nex = head[u], head[u] = top, o && (add(v, u, false), 1538);
}

inline void dfs(int u, int fa){
    dep[u] = dep[fa] + 1, fat[u] = fa;
    for (int i = head[u]; i; i = edge[i].nex){
        const int v = edge[i].v;
        v != fa && (dfs(v, u), 1538);
    }
}

inline int lca(int x, int y, int p){
    dep[x] < dep[y] && (swap(x, y), 1538);
    for (;dep[x] > dep[y]; pth[p].set(x), x = fat[x]);
    if (x == y){
        return x;
    }
    for (;x != y; pth[p].set(x), pth[p].set(y), x = fat[x], y = fat[y]);
    return x;
}

inline long long ksm(long long a, long long b){
    long long res = 1;
    for (;b; b & 1 && (res = res * a % mod), a = a * a % mod, b >>= 1);
    return res;
}

template<typename Tp_x, typename Tp_y>
inline int mod_add(Tp_x &x, Tp_y y){
    return x += y, x >= mod ? x -= mod : x;
}

namespace DSU{
    int fa[maxm];
    inline int getf(int x){
        return fa[x] == x ? x : fa[x] = getf(fa[x]);
    }
    inline void merge(int x, int y){
        (x = getf(x)) != (y = getf(y)) && (fa[x] = y);
    }
}

using namespace DSU;

int main(){
    scanf("%d %d %d", &n, &m, &k);
    for (int i = 1; i < n; i++){
        int u, v;
        scanf("%d %d", &u, &v), add(u, v);
    }
    dfs(1, 0);
    for (int i = 1; i <= m; i++){
        int x, y;
        scanf("%d %d", &x, &y), lca(x, y, i);
        for (int j = 1; j < i; j++){
            for (int k = 1; k <= n; k++){
                con[i][j] |= pth[i].test(k) && pth[j].test(k);
            }
        }
    }
    for (int sta = 0; sta < 1 << m; sta++){
        bool cnt = false;
        iota(fa + 1, fa + m + 1, 1), now.reset();
        for (int i = 1; i <= m; i++){
            if (sta >> i - 1 & 1){
                cnt ^= 1, now |= pth[i];
                for (int j = 1; j < i; j++){
                    if (sta >> j - 1 & 1 && con[i][j]){
                        merge(i, j);
                    }
                }
            }
        }
        int sum = 0;
        for (int i = 1; i <= m; i++){
            sum += sta >> i - 1 & 1 && fa[i] == i;
        }
        const int val = ksm(k, n - 1 - now.count() + sum);
        mod_add(res, cnt ? mod - val : val);
    }
    printf("%d", res);

return 0;
}