题解:AT_abc235_h [ABC235Ex] Painting Weighted Graph

· · 题解

Solution

建出多叉 Kruskal 重构树。

问题变为:给你一棵树,每次能将一个节点的子树中所有叶子结点标记。问不超过 k 次操作能标记出多少种可能的叶子集合。

考虑设 dp_{u,k} 表示,u 的子树内,需要覆盖次数至少为 k 的方案数。

直接使用树形背包转移即可。注意如果所有子树都是全满(这种情况只会出现一次而且容易处理),应当 dp_{u,\deg_u} \leftarrow dp_{u,\deg_u}-1dp_{u,1} \leftarrow dp_{u,1}+1

复杂度 O(nk)。常数可能会比较大,看你实现。

#include<bits/stdc++.h>
#define ll long long
#define ffor(i,a,b) for(int i=(a);i<=(b);i++)
#define roff(i,a,b) for(int i=(a);i>=(b);i--)
using namespace std;
const int MAXN=2e5+10,MOD=998244353;
int n,m,k,tot,rt[MAXN],sze[MAXN],fa[MAXN],ok[MAXN];
map<int,vector<pair<int,int>>> ed;
int find(int k) {return (fa[k]==k)?k:(fa[k]=find(fa[k]));}
vector<int> G[MAXN],vc[MAXN],dp[MAXN];
void merge(vector<int>& u,vector<int>& v) {
    if(u.size()<v.size()) swap(u,v);
    for(auto id:v) u.push_back(id);
    vector<int> ().swap(v);
    return ;    
}
int tmp[MAXN];
void dfs(int u) {
    if(u<=n) return sze[u]=1,dp[u].push_back(1),dp[u].push_back(1),void();
    dp[u].push_back(1),dp[u].push_back(0),sze[u]=0;
    for(auto v:G[u]) {
        dfs(v);
        ffor(i,0,min(sze[u]+sze[v],k)) tmp[i]=0;
        ffor(i,0,sze[u]) ffor(j,0,min(sze[v],k-i)) tmp[i+j]=(tmp[i+j]+1ll*dp[u][i]*dp[v][j])%MOD;
        sze[u]+=sze[v],dp[u].resize(min(sze[u],k)+1);
        ffor(i,0,min(sze[u],k)) dp[u][i]=tmp[i];
        vector<int>().swap(dp[v]);
    }
    if(!ok[u]) {
        if(G[u].size()<dp[u].size()) dp[u][G[u].size()]=(dp[u][G[u].size()]-1)%MOD;
        dp[u][1]=(dp[u][1]+1)%MOD;
    }
    return ;
}
signed main() {
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    cin>>n>>m>>k,tot=n;
    ffor(i,1,m) {int a,b,c;cin>>a>>b>>c,ed[c].push_back({a,b});}
    ffor(i,1,n) fa[i]=i,rt[i]=i;
    for(auto pr:ed) {
        auto vcc=pr.second;
        set<int> occ;
        for(auto e:vcc) {
            int u=e.first,v=e.second;
            u=find(u),v=find(v);
            if(u==v) continue ;
            occ.insert(u),occ.insert(v);
            if(!vc[u].size()) vc[u].push_back(rt[u]);
            if(!vc[v].size()) vc[v].push_back(rt[v]);
            fa[v]=u,merge(vc[u],vc[v]);
        }
        for(auto id:occ) if(find(id)==id) {
            ++tot;
            for(auto s:vc[id]) G[tot].push_back(s);
            rt[id]=tot;
            vector<int> ().swap(vc[id]);
        }
    }
    set<int> st;
    ffor(i,1,n) if(find(i)==i) st.insert(rt[i]);
    if(st.size()!=1) {++tot,ok[tot]=1;for(auto id:st) G[tot].push_back(id);}
    dfs(tot);
    int ans=0;
    ffor(i,0,dp[tot].size()-1) ans=(ans+dp[tot][i])%MOD;
    cout<<(ans%MOD+MOD)%MOD;
    return 0;
}