【题解】AT3611 Tree MST
LinkyChristian · · 题解
xpp杂题清除计划
这题和P8207 [THUPC2022 初赛] 最小公倍树有点像,都可以用优化选边的 kruskal 解决。
很显然,对于一张边数为
但是也正是因为这是一张完全图,所以我们可以用一个定理来解决问题。
定理:对于任意一张完全图
G=(V,E) ,选取数个边集(E_1,E_2,E_3,...,E_k) 使其完全覆盖边集E 。对每个集合E_i 求最小生成树,得到边集E_{MST_{i}} 。再对E_{MST_{1}},E_{MST_{2}},E_{MST_{3}},...,E_{MST_{k}} 求最小生成树,其最小生成树一定也是E 的最小生成树。
证明:考虑反证。假设最终图有一种比用
因此考虑淀粉质,假设我们以当前的分治中心
#include<bits/stdc++.h>
#define N 200010
using namespace std;
typedef long long ll;
int n,en;
ll w[N],INF=1e18;
int cnt,head[N],to[N<<1],nxt[N<<1];
ll val[N<<1];
int S,son[N],siz[N],vis[N],rt;
struct edge{
int u,v;
ll w;
const bool operator < (const edge o) const {return w<o.w;}
}e[N<<4];
void insert(int u,int v,int w) {
cnt++;
to[cnt]=v;
val[cnt]=w;
nxt[cnt]=head[u];
head[u]=cnt;
}
int read() {
int res=0,f=1;char ch=getchar();
while(!isdigit(ch)) f=ch=='-'?-1:1,ch=getchar();
while(isdigit(ch)) res=res*10+ch-'0',ch=getchar();
return f*res;
}
void gtrt(int now,int fa) {
siz[now]=1,son[now]=0;
for(int i=head[now]; i; i=nxt[i]) if(to[i]!=fa&&!vis[to[i]]) {
gtrt(to[i],now);
siz[now]+=siz[to[i]];
son[now]=max(son[now],siz[to[i]]);
}
son[now]=max(son[now],S-siz[now]);
if(son[now]<son[rt]) rt=now;
}
ll sk[N];
int id[N],tp,mn;
void dfs(int now,int fa,ll dis) {
sk[++tp]=w[now]+dis,id[tp]=now;
if(sk[tp]<sk[mn]) mn=tp;
for(int i=head[now]; i; i=nxt[i]) if(!vis[to[i]]&&to[i]!=fa)
dfs(to[i],now,dis+val[i]);
}
void solve(int now) {
vis[now]=1;
tp=mn=0,dfs(now,0,0);
for(int i=1; i<=tp; i++) if(i!=mn) e[++en]=edge{id[i],id[mn],sk[i]+sk[mn]};
int tmp=S;
for(int i=head[now]; i; i=nxt[i]) if(!vis[to[i]]) {
S=siz[to[i]]<siz[now]?siz[to[i]]:tmp-siz[now];
rt=0,gtrt(to[i],now),solve(rt);
}
}
int fa[N];
int find(int x) {
while(fa[x]!=x) x=fa[x]=fa[fa[x]];
return x;
}
int main()
{
n=read();
for(int i=1; i<=n; i++) w[i]=read();
for(int i=1; i<n; i++) {
int u=read(),v=read(),w=read();
insert(u,v,w);
insert(v,u,w);
}
S=son[0]=n,rt=0,gtrt(1,0),sk[0]=INF;
solve(rt);
for(int i=1; i<=n; i++) fa[i]=i;
sort(e+1,e+en+1);int tot=0;ll sum=0;
for(int i=1; i<=en; i++) {
int u=find(e[i].u),v=find(e[i].v);
if(u!=v) fa[u]=v,tot++,sum+=e[i].w;
if(tot==n-1) {printf("%lld",sum);return 0;}
}
return 0;
}
我是不会说我开 nlogn 的数组只开了 n*4 导致错了好几发的