题解 AT3611 【Tree MST】

· · 题解

题意 : 给定一棵 n 个点的树,边有边权。

按照如下规则建立一张完全图 : x,y之间的边长为w[x]+w[y]+dis(x,y)

求完全图的 MST 边权和。

------------ - $O(n\log^2 n)

一般地,对于(完全图)MST问题,我们可以先选定一个边集,做一次MST(不连通不管),把剩余的边保留,最后再做一次MST,这样一定能得到最优解。

这是个树形结构,我们可以点分治,考虑跨越重心的路径所生成的边集。

以分治中心为根,令p[u]=w[u]+dep[u],则连接两个点的代价就是p[u]+p[v](关建边的长度)

我们只需保留p最小的一个点,然后把其他点都和他相连,显然就是MST了。

这里会产生子树内自己连的路径,但是比直接连劣所以不会影响答案。

点分治一共会产生O(n\log n)条边,然后跑一个kruskalO(n\log^2 n)

#include<algorithm>
#include<cstdio>
#include<vector>
#define ll long long
#define MaxN 200500
using namespace std;
inline int read()
{
  int X=0;char ch=0;
  while(ch<48||ch>57)ch=getchar();
  while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
  return X;
}
vector<int> g[MaxN],l[MaxN];
int tp[MaxN],tn,ms[MaxN],siz[MaxN];
bool vis[MaxN];
void pfs(int u,int fa)
{
  tp[++tn]=u;
  siz[u]=1;ms[u]=0;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v]){
      pfs(v,u);
      siz[u]+=siz[v];
      ms[u]=max(ms[u],siz[v]);
    }
}
int getrt(int u)
{
  tn=0;pfs(u,0);
  int rt=0;
  for (int i=1;i<=tn;i++){
    ms[tp[i]]=max(ms[tp[i]],tn-siz[tp[i]]);
    if (ms[tp[i]]<ms[rt])rt=tp[i];
  }return rt;
}
ll w[MaxN],dep[MaxN],mp;
void dfs(int u,int fa)
{
  if (w[mp]+dep[mp]>w[u]+dep[u])mp=u;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v]){
      dep[v]=dep[u]+l[u][i];
      dfs(v,u);
    }
}
struct Line
{int f,t;ll len;}s[MaxN<<5];
int tot;
void solve(int u)
{
  dep[u]=mp=0;dfs(u,0);
  ll sav=w[mp]+dep[mp];
  for (int i=1;i<=tn;i++)
    s[++tot]=(Line){tp[i],mp,sav+w[tp[i]]+dep[tp[i]]};
  vis[u]=1;
  for (int i=0,v;i<g[u].size();i++)
    if (!vis[v=g[u][i]])
      solve(getrt(v));
}
int n,f[MaxN];
bool cmp(const Line &A,const Line &B)
{return A.len<B.len;}
int findf(int u)
{return f[u]==u ? u : f[u]=findf(f[u]);}
bool merge(int x,int y)
{
  x=findf(x);y=findf(y);
  if (x==y)return 0;
  f[x]=y;return 1;
}
int main()
{
  n=read();w[0]=1ll<<60;ms[0]=n+1;
  for (int i=1;i<=n;i++)w[i]=read();
  for (int i=1,fr,to,len;i<n;i++){
    fr=read();to=read();len=read();
    g[fr].push_back(to);
    l[fr].push_back(len);
    g[to].push_back(fr);
    l[to].push_back(len);
  }solve(getrt(1));
  ll ans=0;
  for (int i=1;i<=n;i++)f[i]=i;
  sort(s+1,s+tot+1,cmp);
  for (int i=1;i<=tot;i++)
    if (merge(s[i].f,s[i].t))
      ans+=s[i].len;
  printf("%lld",ans);
}

考虑使用用Boruvka算法。

每一轮,我们要对每个连通块找到最小出边(到达其他连通块)

对于一个已经联通块染色的局面,我们可以进行up and down DP.

首先求出从子树到达本身的最优决策,问题来了,颜色不能相同……

\texttt{stO EI Orz}

我们记录最大值和次大值即可,对颜色去重,这样万一最大的碰了还有补刀的。 (本蒟蒻这里可能写丑了)

然后再次dfs计算从上面延伸过来的即可。

每轮会使连通块个数减半,复杂度就是O(n\log n),常数较大。

#include<algorithm>
#include<cstdio>
#include<vector>
#define ll long long
#define MaxN 200500
using namespace std;
inline int read()
{
  int X=0;char ch=0;
  while(ch<48||ch>57)ch=getchar();
  while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
  return X;
}
vector<int> g[MaxN],l[MaxN];
int w[MaxN],c[MaxN];
#define Pr pair<ll,int>
#define mp make_pair
#define fir first
#define sec second
struct Data
{
  Pr x,x2;int c;
  bool chk(const Data &t){
    if (t.c==c){
      x=min(x,t.x);
      x2=min(x2,t.x2);
    }else {
      if (t.x<=x){
        x2=min(x,t.x2);
        c=t.c;x=t.x;
      }else x2=min(x2,t.x);
    }
  }
  Pr get(int tc){
    if (tc==c)return x2;
    return x;
  }
}f[MaxN];
void dfs1(int u,int fa)
{
  f[u].x=mp(w[u],u);
  f[u].c=c[u];
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa){
      dfs1(v,u);
      f[v].x.fir+=l[u][i];
      f[v].x2.fir+=l[u][i];
      f[u].chk(f[v]);
      f[v].x.fir-=l[u][i];
      f[v].x2.fir-=l[u][i];
    }
}
void dfs2(int u,int fa)
{
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa){
      f[u].x.fir+=l[u][i];
      f[u].x2.fir+=l[u][i];
      f[v].chk(f[u]);
      dfs2(v,u);
      f[u].x.fir-=l[u][i];
      f[u].x2.fir-=l[u][i];
    }
}
int n,cnt;
int findf(int u)
{return c[u]==u ? u : c[u]=findf(c[u]);}
bool merge(int x,int y)
{
  x=findf(x);y=findf(y);
  if (x==y)return 0;
  c[x]=y;return 1;
}
#define INF (1ll<<60)
Pr p[MaxN];
int main()
{
  n=read();
  for (int i=1;i<=n;i++)w[i]=read();
  for (int i=1,fr,to,len;i<n;i++){
    fr=read();to=read();len=read();
    g[fr].push_back(to);
    l[fr].push_back(len);
    g[to].push_back(fr);
    l[to].push_back(len);
  }for (int i=1;i<=n;i++)c[i]=i;
  ll ans=0;
  while(cnt<n-1){
    for (int i=1;i<=n;i++){
      p[i].fir=f[i].x.fir=f[i].x2.fir=INF;
      f[i].x.sec=f[i].x2.sec=f[i].c=0;
    }dfs1(1,0);dfs2(1,0);
    for (int i=1;i<=n;i++){
      Pr sav=f[i].get(c[i]);
      sav.fir+=w[i];
      p[c[i]]=min(p[c[i]],sav);
    }for (int i=1;i<=n;i++)
      if (p[i].fir<INF&&merge(i,p[i].sec))
        {ans+=p[i].fir;cnt++;}
    for (int i=1;i<=n;i++)findf(i);
  }printf("%lld",ans);
}