P8906 [USACO22DEC] Breakdown P

· · 题解

P8906 [USACO22DEC] Breakdown P

首先折半搜,将问题形式转化为维护 1 到每个点经过 k 条边的最短路 h_i,其中 k\leq 4。我们只考虑 k = 4

删边倒过来变成加边。

考虑时刻维护任意两点间经过两条边的最短路 f_{i, j}。加边时只会对 i = uj = vf_{i, j} 产生修改,共 \mathcal{O}(n) 种情况。对于 h_i,枚举中转点 j ,计算 f_{1, j} + f_{j, i}。注意到当且仅当 ij 等于 uv 时才会因 f 的修改而对 h 产生贡献,因此对于 i\neq u, v,只需用 f_{1, u / v} + f_{u / v, i} 更新,对于 i = uv 用所有 f_{1, j} + f_{j, i} 更新。

单次加边的时间复杂度均摊 \mathcal{O}(n),注意特判 u = 1

#include <bits/stdc++.h>
using namespace std;
constexpr int N = 300 + 5;
constexpr int M = 1e5 + 5;
int n, m, k;
int ans[M], u[M], v[M], e[N][N];
void cmin(int &x, int y) {
  x = x < y ? x : y;
}
struct solver {
  int st, k;
  int e[N][N], f[N][N], h[N];
  void init(int _k, int _st) {
    k = _k, st = _st;
    memset(e, 0x3f, sizeof(e));
    memset(f, 0x3f, sizeof(f));
    memset(h, 0x3f, sizeof(h));
    if(!k) {
      h[st] = 0;
    }
  }
  void add(int u, int v, int w) {
    e[u][v] = w;
    if(!k) return;
    if(k == 1) {
      if(u == st) {
        h[v] = w;
      }
      return;
    }
    for(int i = 1; i <= n; i++) {
      cmin(f[i][v], e[i][u] + w);
      cmin(f[u][i], w + e[v][i]);
    }
    if(k == 2) {
      for(int i = 1; i <= n; i++) {
        cmin(h[i], f[st][i]);
      }
    }
    auto p = k == 3 ? e : f;
    if(k > 2) {
      if(u == st) {
        for(int i = 1; i <= n; i++) {
          for(int j = 1; j <= n; j++) {
            cmin(h[j], f[st][i] + p[i][j]);
          }
        }
      }
      else {
        for(int i = 1; i <= n; i++) {
          cmin(h[i], f[st][v] + p[v][i]);
          cmin(h[i], f[st][u] + p[u][i]);
          cmin(h[u], f[st][i] + p[i][u]);
          cmin(h[v], f[st][i] + p[i][v]);
        }
      }
    }
  }
} _1, _n;
int main() {
  #ifdef ALEX_WEI
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
  #endif
  ios::sync_with_stdio(0);
  cin.tie(0), cout.tie(0);
  cin >> n >> k, m = n * n;
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      cin >> e[i][j];
    }
  }
  for(int i = 1; i <= m; i++) {
    cin >> u[i] >> v[i];
  }
  int L = k >> 1, R = k - L;
  _1.init(L, 1), _n.init(R, n);
  for(int i = m; i; i--) {
    ans[i] = 0x3f3f3f3f;
    for(int p = 1; p <= n; p++) {
      cmin(ans[i], _1.h[p] + _n.h[p]);
    }
    _1.add(u[i], v[i], e[u[i]][v[i]]);
    _n.add(v[i], u[i], e[u[i]][v[i]]);
  }
  for(int i = 1; i <= m; i++) {
    if(ans[i] > 1e9) cout << "-1\n";
    else cout << ans[i] << "\n";
  }
  return 0;
}