P8315 [COCI2021-2022#4] Šarenlist 题解

· · 题解

分析

显然,直接求解难,不妨先求不可行的方案,从总数中减去。

对于每一条路径,设它上面有 s 条边,则可以得到有 k^{n-s} 种不可行的方案(因为 s 条边在不可行方案中一定是同色,所以可以合并,加上 n-1-s 条非路径边,共是 k^{n-s} 种方案)。

那么,有一个问题:有些方案会被重复计算。

于是需要 O(2^m) 的时间复杂度去进行容斥。

假设有两条路径,已单独计算。

因为相重叠的路径必定是同色(重复计算的一定是同色方案),所以两条路径可以合并为一个连通块。

此时就可以想到用并查集去维护连通块数量。

加上快速幂,就可以解决本题。

总时间复杂度:O(2^m(mn\alpha(n)+n+\log n))(枚举容斥状态 O(2^m),枚举路径 O(m),合并边 O(n\alpha(n)),计算连通块 O(n),快速幂 O(\log n))。

代码

#include<bits/stdc++.h>
using namespace std;
int n,m,k,a,b,s,t,cnt,sum,c[16],d[16],f[61][10],l[16],dis[61],dep[61],head[61],fat[61];
long long mod=1e9+7,ans,g;
struct node{
    int next,to;
}e[121];//链式前向星
void add(int x,int y){
    e[++cnt].next = head[x];
    e[cnt].to = y;
    head[x] = cnt;
}//建边
void dfs(int x,int fa){
    f[x][0] = fa;
    dep[x] = dep[fa] + 1;
    for(int i = head[x] ; i ; i = e[i].next ){
        int y = e[i].to;
        if(y != fa){
            dis[y] = dis[x]+1;
            dfs(y,x);
        }
    }
}//用于LCA预处理
int lca(int x,int y){
    if(dep[x] < dep[y]) swap(x,y);
    int p = 6;
    while(~p){
        if( dep[f[x][p]] >= dep[y] ) x = f[x][p];
        p--;
    }
    if(x == y) return x;
    p = 6;
    while(~p){
        if(f[x][p] != f[y][p]) x = f[x][p],y = f[y][p];
        p--;
    }
    return f[x][0];
}//最近公共祖先的计算
int find(int x){
    if (fat[x] != x)return fat[x] = find(fat[x]);
    return x;
}//并查集基本操作
long long ksm(int mi){
    long long res = k,ll = 1;
    while(mi){
        if(mi & 1)ll = ll * res % mod;
        res = res * res % mod,mi >>=1;
    }
    return ll;
}//快速幂
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m>>k;
    for(int i = 1;i < n;i++){
        cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    dfs(1,0);
    for(int i = 1;i <= 6;i++){
        for(int j = 1;j <= n;j++){
            f[j][i] = f[f[j][i-1]][i-1];
        }
    }
    for(int i=1;i<=m;i++){
        cin>>c[i]>>d[i];
        l[i] = lca(c[i],d[i]);
    }
    ans = ksm(n-1);
    for(int i = 1;i < (1<<m);i++){
        g = 1,sum = 0;
        for(int j = 2;j <= n;j++){
            fat[j] = j;
        }//并查集初始化
        for(int j = 0;j < m;j++){
            if(i>>j&1){
                g = -g;//为了计算容斥时的正负性
                s = c[j+1];
                while(s != l[j+1]){
                    if(f[s][0] == l[j+1])break;
                    a = find(s),b = find(f[s][0]);
                    if(a != b)fat[a] = b;//合并边
                    s = f[s][0];
                }//从一段跳到LCA
                t = d[j+1];
                while(t != l[j+1]){
                    if(f[t][0] == l[j+1])break;
                    a = find(t),b = find(f[t][0]);
                    if(a != b)fat[a] = b;
                    t = f[t][0];
                }//从另一端跳到LCA
                if(s != l[j+1] && t != l[j+1]){
                    a = find(s),b = find(t);
                    if(a != b)fat[a] = b;
                }//如果两者都不是LCA,则LCA下路径中的两条边也应合并到一起
            }
        }
        for(int j = 2 ;j <= n;j++){
            if(fat[j] == j)sum++;
        }//计算连通块数量
        ans = (ans + g * ksm(sum)) % mod;
    }
    ans = ( ans % mod + mod ) %mod;//防止ans变成负数
    cout<<ans;
}