题解:P6199 [EER1] 河童重工

· · 题解

就让我站在巨人的肩膀上吧!

博客园

本文可能会有一些细节没讲到。

如果还有其他方法,欢迎在评论区补充。

luogu - P6199 [EER1] 河童重工

给两棵树 T1,T2n 个节点的完全图,任意两点之间的边权为 T1.dis(i,j) + T2.dis(i,j),求MST。

很显然这是 [AT_cf17_final_j] Tree MST 的加强版。

法一

点分治套点分治。

直接得到可能用到的 O(n \log^2 n) 条边(点分治,然后虚树,然后点分治),最后统一求一遍 MST。

复杂度 O(n \log^2 n + n \log^2 n \log (n \log^2 n))

法二

点分治套 brovka。

直接得到可能用到的 O(n \log n) 条边(点分治,然后虚树,然后 brovka),最后统一求一遍 MST。

复杂度 O(n \log^2 n + n \log^2 n \ \alpha(n) + n \log n \log (n \log n))

法二的进阶

对每个虚树都求 MST 还是太麻烦了,有没有更好的办法?有的兄弟,有的。(参考 https://www.luogu.com.cn/article/qu23uw2l )

假设对 T2 点分治,我们把两点距离算 T2.dis(i,root) + T2.dis(j,root) + T1.dis(i,j),显然不影响最终答案。设 w_i = T2.dis(i,root)。设 E 表示最终 MST 用到的边集。

对于一棵虚树,每个点 i,求 w_j + T1.dis(i,j) 最小的 j,记作 pre_i,且记 w_j + T1.dis(i,j)w'_i(关于求 j,这是一个基础的 dijkstra 问题……堆初始 push 所有点即可)。

对于虚数上连接的两点 u,v,在 E 中加入 (pre_u,pre_v),边权为 w'_u + w'_v + T1.dis(u,v)

对每个点 u,在 E 中加入 (u,pre_u),边权为 w'_u + w_u

这样最后得到的是 O(n \log n) 条边,统一求一遍 MST。

这为什么正确?证明:

复杂度 O(n \log^2 n + n \log^2 n + n \log n \log(n \log n))

法三

brovka 套点分治。(参考 https://www.luogu.com.cn/article/97gxfx4s )

brovka 后,每一次,需要对每个 i 求出 T1.dis(i,j) + T2.dis(i,j) 最小且满足 col_i \ne col_jj

假设对 T1 用 brovka,每一层,对 T1 点分治,求出当前子树在 T2 中的虚树,树形 dp(额外记录次小值,满足最小值和次小值的 col 不同)。

每一次的复杂度 O(n \log^2 n \times 虚树的常数),再乘上 brokva 的 \log n 直接爆炸。

可以把点分树和每一层的虚树建出来并保留,那么 brovka 的每一次,省去了建虚树的复杂度,那么 O(n \log n)

最终复杂度 O(n \log^2 n + n \log^2 n)

Code

这里选择法二,个人感觉相比法三要好写一些,因为蒟蒻不太会存虚树,感觉存虚树有点麻烦……难评。

我写代码的时候出了几个问题:

这题的样例还是太弱了,作为一个善良的人,赠送一组样例:

9
1 2 7
2 3 5
3 4 6
4 5 3
1 6 7
1 7 1
3 8 1
1 9 8
1 2 7
1 3 1
3 4 9
3 5 2
3 6 4
6 7 7
4 8 4
2 9 0

102
#include <bits/stdc++.h>

using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int MAXN = 1e5 + 3, MAXL = 18;

struct Edge{
  int u, v, w;
};

int n;

inline void Merge(vector<Edge> &x, vector<Edge> y){
  x.insert(x.end(), y.begin(), y.end());
}

namespace Tree1{
  int anc[MAXN][MAXL], dep[MAXN], ldep[MAXN], dfn[MAXN], depth = 0;
  vector<PII> EG[MAXN];
  void ADD(int U, int V, int W){
    EG[U].push_back({V, W}), EG[V].push_back({U, W});
  }
  void dfs(int x, int dad){
    anc[x][0] = dad, dfn[x] = ++depth;
    for(PII e : EG[x]){
      if(e.first == dad) continue;
      dep[e.first] = dep[x] + 1, ldep[e.first] = ldep[x] + e.second;
      dfs(e.first, x);
    }
  }
  void ply_m(){
    dep[1] = ldep[1] = 0, dfs(1, 0);
    for(int l = 1; l < MAXL; l++){
      for(int i = 1; i <= n; i++) anc[i][l] = anc[anc[i][l-1]][l-1];
    }
  }
  int LCA(int x, int y){
    if(dep[x] > dep[y]) swap(x, y);
    for(int k = dep[y] - dep[x], l = 0; l < MAXL; l++){
      if((k >> l) & 1) y = anc[y][l];
    }
    if(x == y) return x;
    for(int l = MAXL - 1; l >= 0; l--){
      if(anc[x][l] != anc[y][l]) x = anc[x][l], y = anc[y][l];
    }
    return anc[x][0];
  }

  int dp[MAXN], pre[MAXN];
  vector<PII> eg[MAXN];
  vector<Edge> Solve(vector<PII> S){
    vector<PII> vt;
    //cout << S.size() << " %%%%%%%%%%\n";
    //for(PII x : S) cout << x.first << " " << x.second << " $$$\n";
    sort(S.begin(), S.end(), [](PII x, PII y){ return dfn[x.first] < dfn[y.first]; });
    vt.push_back(S[0]);
    for(int i = 1; i < S.size(); i++){
      vt.push_back(S[i]), vt.push_back({LCA(S[i].first, S[i - 1].first), int(2e9)});
    }
    sort(vt.begin(), vt.end(), [](PII x, PII y){ return x.first == y.first ? x.second < y.second : dfn[x.first] < dfn[y.first]; });

    priority_queue<PII, vector<PII>, greater<PII>> pq;
    vector<PII> E, V;
    for(PII x : vt) eg[x.first].clear();
    pq.push({vt[0].second, vt[0].first}), dp[vt[0].first] = vt[0].second, pre[vt[0].first] = vt[0].first, V.push_back(vt[0]);
    for(int i = 1, la = vt[0].first; i < vt.size(); i++){
      if(vt[i].first != vt[i - 1].first){
        int x = vt[i].first, f = LCA(x, la);
        eg[f].push_back({x, ldep[x] - ldep[f]}), eg[x].push_back({f, ldep[x] - ldep[f]});
        pq.push({vt[i].second, x}), dp[x] = vt[i].second, pre[x] = x;
        E.push_back({x, f}), V.push_back(vt[i]);
        la = x;
      }
    }
    while(!pq.empty()){
      PII i = pq.top();
      pq.pop();
      swap(i.first, i.second);
      if(dp[i.first] < i.second) continue;
      for(PII e : eg[i.first]){
        int nx = e.first;
        LL nw = e.second + dp[i.first];
        if(dp[nx] > nw){
          dp[nx] = nw, pre[nx] = pre[i.first], pq.push({nw, nx});
        }
      }
    }
    vector<Edge> ret;
    for(PII e : E){
      //cout << e.first << " " << pre[e.first] << " " <<e.second << " " << pre[e.second ] << "\n";
      if(pre[e.first] != pre[e.second]){
        ret.push_back({pre[e.first], pre[e.second], dp[e.first] + dp[e.second] + abs(ldep[e.first] - ldep[e.second])});
      }
    }
    for(PII x : V){
      if(pre[x.first] != x.first && dp[x.first] < int(2e9) && x.second < int(2e9)){
        ret.push_back({pre[x.first], x.first, dp[x.first] + x.second});
      }
    }
    //cout << ret.size() << "\n";
    return ret;
  }
}

namespace Tree2{
  vector<Edge> ret;
  int vis[MAXN], Size, root, sz[MAXN], mx[MAXN];
  vector<PII> eg[MAXN];

  void ADD(int U, int V, int W){
    eg[U].push_back({V, W}), eg[V].push_back({U, W});
  }

  void Get_root(int x, int dad){
    sz[x] = 1, mx[x] = 0;
    for(PII e: eg[x]){ int nxt = e.first;
      if(vis[nxt] || nxt == dad) continue;
      Get_root(nxt, x), sz[x] += sz[nxt], mx[x] = max(mx[x], sz[nxt]);
    }
    mx[x] = max(mx[x], Size - sz[x]);
    if(root == 0 || mx[root] > mx[x]) root = x;
  }
  vector<PII> S;
  void dfs(int x, int dad, int ldep){
    S.push_back({x, ldep});
    for(PII e : eg[x]){ int nxt = e.first;
      if(vis[nxt] || nxt == dad) continue;
      dfs(nxt, x, ldep + e.second);
    }
  }
  void Solve(int x){
    vis[x] = 1; 
    S.clear(), dfs(x, 0, 0), Merge(ret, Tree1::Solve(S));
    for(PII e : eg[x]){ int nxt = e.first;
      if(vis[nxt]) continue;
      Size = sz[nxt], root = 0, Get_root(nxt, x), Solve(root);
    }
  }
}

int fa[MAXN];
int Getf(int x){ return fa[x] == x ? x : fa[x] = Getf(fa[x]); }

int main(){
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n;
  for(int i = 1, U, V, W; i < n; i++){
    cin >> U >> V >> W, Tree1::ADD(U, V, W);
  }
  for(int i = 1, U, V, W; i < n; i++){
    cin >> U >> V >> W, Tree2::ADD(U, V, W);
  }
  Tree1::ply_m();
  Tree2::Size = n, Tree2::root = 0, Tree2::Get_root(1, 0), Tree2::Solve(Tree2::root);
  vector<Edge> E = Tree2::ret;
  sort(E.begin(), E.end(), [](Edge i, Edge j){ return i.w < j.w; });
  for(int i = 1; i <= n; i++) fa[i] = i;
  LL ans = 0;
  for(Edge e : E){
    //cout << e.u << " " << e.v << " " << e.w << " ^\n";
    int fx = Getf(e.u), fy = Getf(e.v);
    if(fx != fy){
      fa[fx] = fy, ans += e.w;
    }
  }
  cout << ans;
  return 0;
}
/*
9
1 2 7
2 3 5
3 4 6
4 5 3
1 6 7
1 7 1
3 8 1
1 9 8
1 2 7
1 3 1
3 4 9
3 5 2
3 6 4
6 7 7
4 8 4
2 9 0

102

*/