题解 P6573 【[BOI 2017] Toll】

· · 题解

给出一张 n 个点,m 条边的有向图,q 次询问任意两点间最短路。

给出一常数 k,保证每条边 (u, v) 保证 \left\lfloor \frac{u}{k} \right\rfloor + 1 = \left\lfloor \frac{v}{k} \right\rfloor

感觉 NOIP2018 普及了动态 DP 以后,这题变简单了。

对于每个块维护一个矩阵做动态 DP。

设进入这个块时,从起点到每个点的距离依次为

\begin{bmatrix} f_0 & f_1 & \cdots & f_{k - 1} \end{bmatrix}

出完这个块之后,需要从这个块内的边走出去,更新矩阵变成

\begin{bmatrix} g_0 & g_1 & \cdots & g_{k - 1} \end{bmatrix}

首先有一个暴力 DP:

g_j = \min\limits_{i = 0} ^ {k - 1} \left(f_i + w_{i, j}\right)

其中 i, j 表示该块内的 i 号点连向下一个块内 j 号点的边的边权,没有设为 +\infty

学习了动态 DP 之后,知道这个东西是可以写成矩阵 “乘法” 形式的,转移矩阵是:

\begin{bmatrix} w_{0, 0} & w_{0, 1} & \cdots & w_{0, k - 1} \\ w_{1, 0} & w_{1, 1} & \cdots & w_{1, k - 1} \\ \vdots & \vdots & \ddots & \vdots \\ w_{k - 1, 0} & w_{k - 1, 1} & \cdots & w_{k - 1, k - 1} \end{bmatrix}

这个东西是 \min+ 套起来,是有结合律的。所以倍增或线段树维护一下区间连乘积,每次查询的时候直接乘即可。

经过一些细节优化可以做到,时间复杂度 \mathcal{O}((n + q) k ^ 2 \log n)

代码:

#include <algorithm>
#include <cstdio>
#include <cstring>

const int MaxN = 50000, MaxK = 5;
const int MaxLog = 16;
const int INF = 0x3F3F3F3F;

int N, M, K, Q;

struct matrix_t {
  int mat[MaxK][MaxK];
  matrix_t() { memset(mat, 0x3F, sizeof mat); }

  inline friend matrix_t operator*(const matrix_t &a, const matrix_t &b) {
    matrix_t c;
    for (int i = 0; i < K; ++i)
      for (int j = 0; j < K; ++j) {
        int x = INF;
        for (int k = 0; k < K; ++k)
          x = std::min(x, a.mat[i][k] + b.mat[k][j]);
        c.mat[i][j] = x;
      }
    return c;
  }
};

struct vector_t {
  int mat[MaxK];
  vector_t() { memset(mat, 0x3F, sizeof mat); }

  inline friend vector_t operator*(const vector_t &a, const matrix_t &b) {
    vector_t c;
    for (int i = 0; i < K; ++i)
      for (int j = 0; j < K; ++j)
        c.mat[j] = std::min(c.mat[j], a.mat[i] + b.mat[i][j]);
    return c;
  }
};

matrix_t F[MaxLog + 1][MaxN + 5];

void init() {
  scanf("%d %d %d %d", &K, &N, &M, &Q);
  for (int i = 1; i <= M; ++i) {
    int u, v, w;
    scanf("%d %d %d", &u, &v, &w);
    F[0][u / K].mat[u % K][v % K] = std::min(F[0][u / K].mat[u % K][v % K], w);
  }
}

inline void query(vector_t &f, int l, int r) {
  int x = l;
  for (int i = MaxLog; i >= 0; --i)
    if ((r - l + 1) & (1 << i)) {
      f = f * F[i][x];
      x += (1 << i);
    }
}

void solve() {
  for (int i = 1; (1 << i) <= N / K + 1; ++i)
    for (int j = 0; j + (1 << i) - 1 < (N / K + 1); ++j)
      F[i][j] = F[i - 1][j] * F[i - 1][j + (1 << (i - 1))];
  for (int q = 1; q <= Q; ++q) {
    int a, b;
    scanf("%d %d", &a, &b);
    if (a / K == b / K) puts("-1");
    else {
      vector_t f;
      f.mat[a % K] = 0;
      query(f, a / K, b / K - 1);
      if (f.mat[b % K] == INF) puts("-1");
      else printf("%d\n", f.mat[b % K]);
    }
  }
}

int main() {
  init();
  solve();
  return 0;
}