P8315 [COCI2021-2022#4] Šarenlist

· · 题解

P8315 [COCI2021-2022#4] Šarenlist

分析

首先合法路径上至少有两条边颜色不同,这不太好统计,我们可以反过来想:

合法方案数 = 总方案数 - 不合法的方案数

总方案数显然是 k^{n-1}, 考虑计算不合法方案数。

先考虑 m=1 的情况:

显然,不合法方案的这一条路径上的颜色都是相同的,我们就想到像缩点一样,把这一条路径的边缩成一条边,这样不合法方案数就好算了,设这一条路径包含 c 条边,则不合法方案数为 k^{n-c}

然而,如果 m\geq2,这样算会重复计算多条路径同时不合法的方案,我们发现 n\leq60,m\leq15,这么小!于是可以直接暴力枚举容斥。

具体来说,就是枚举不合法路径的集合 S,对于每条路径,把它所有边合并到某一条边上(如果路径有相交,当然是合并成一条边了)。然后统计合并后的边数 cnt,当 |S| 为偶数,让答案加上 cnt, 当 |S| 为奇数,让答案减去 cnt,枚举路径集合可以直接状态压缩,合并可以用并查集实现,时间复杂度为 O(nm2^m),具体细节看代码。

代码

#include <bits/stdc++.h>
#define N 650
#define M 1500
#define ll long long//记得开long long
using namespace std;
const ll mod=1e9+7;
int n,m;
ll k;
int tot,first[N],nxt[M],ver[M];
void add(int s,int e) {
    nxt[++tot]=first[s];
    first[s]=tot;
    ver[tot]=e;
}
int id[M],idc;
int s[N],t[N];
void init() {
    for(int i=1;i<n;++i) {
        int s,e;
        scanf("%d%d",&s,&e);
        add(s,e),add(e,s);
        id[tot]=id[tot-1]=++idc;//给边编号方便存路径
    }
    for(int i=1;i<=m;++i) {
        scanf("%d%d",&s[i],&t[i]);
    }
}
int road[N][N],cr[N];//road存路径
bool dfs(int now,int u,int T,int fa) {
    if(now==T) {
        return 1;
    }
    for(int i=first[now];i;i=nxt[i]) {
        int v=ver[i];
        if(v==fa) continue;
        if(dfs(v,u,T,now)) {
            road[u][++cr[u]]=id[i];
            return 1;
        }
    }
    return 0;
}
int fa[N];//并查集
void reset() {
    for(int i=1;i<=idc;++i) {
        fa[i]=i;
    }
}
int get(int p) {
    if(fa[p]==p) return p;
    return fa[p]=get(fa[p]);
}
void merge(int x,int y) {
    fa[get(x)]=get(y);
}
bitset<N> vis;
ll powk[N],ans;
int main() {
    scanf("%d%d%lld",&n,&m,&k);
    init();
    for(int i=1;i<=m;++i) {
        dfs(s[i],i,t[i],0);
    }
    powk[0]=1;
    for(int i=1;i<=n;++i) {
        powk[i]=powk[i-1]*k%mod;//预处理k的幂
    }
    int ed=1<<m;
    ans=powk[n-1];
    for(int i=1;i<ed;++i) {//状态压缩
        int cnt=0,c1=0;//cnt存缩边后边数,c1表示统计了几条路径
        reset();
        vis.reset();
        for(int j=0;j<m;++j) {
            if((i>>j)&1) {
                c1++;
                for(int k=2;k<=cr[j+1];++k) {
                    merge(road[j+1][k],road[j+1][1]);
              //把路径上的边合并到第一条边上
                }
            }
        }
        for(int j=1;j<=idc;++j) {
            if(!vis[get(j)]) {
                vis[get(j)]=1;
                cnt++;
            }
        }
        if(c1&1) {
            ans=(ans-powk[cnt]+mod)%mod;
        } else {
            ans=(ans+powk[cnt])%mod;
        }
    }
    printf("%lld\n",ans);
    return 0;
}