P11343 [KTSC 2023 R1] 出租车旅行

· · 题解

给定一棵 $n$ 个点的树,边有边权。有一个 $n$ 个点的有向完全图,$i \to j$ 的边权为 $a_i + \operatorname{dis}(i, j) \times b_i$,其中 $\operatorname{dis}(i, j)$ 为树上 $i, j$ 两点间的距离。对于所有 $2 \le i \le n$,求 $1 \rightsquigarrow i$ 的最短路。

----

**性质**:最短路上除终点外,其余点的 $b_i$ 单调递减。因此 $b_v > b_1$ 的点只可能作为终点出现。

考虑 DP,设 $f_u$ 为 $1 \rightsquigarrow u$ 且路径上所有点的 $b$ 单调递减的最短路长度,有:

$$
f_u = \min\limits_{b_v > b_u} f_v + a_v + \operatorname{dis}(u, v) \times b_v
$$

再设 $g_u$ 为 $1 \rightsquigarrow u$ 的最短路长度,有:

$$
g_u = \min\limits_{v} f_v + a_v + \operatorname{dis}(u, v) \times b_v
$$

看到树上距离考虑点分治。

点分树上每个点 $u$ 维护子树内点 $(-b_v, f_v + a_v + \operatorname{dis}(u, v) \times b_v)$ 的下凸包。由于按 $b$ 从大到小加点,横坐标 $-b_v$ 单调不降,所以凸包可以用支持 pop_back 的 vector 维护。

算 $f_u$ 就枚举 $u$ 在点分树上的祖先 $v$,相当于查询斜率为 $\operatorname{dis}(u, v)$ 的直线经过凸包上的点时的最小截距。二分即可。

算 $g_u$ 可以用同样的方法:此时 $f$ 已全部求出,对每个分治中心把连通块内所有点排序后建凸包,再逐点二分查询。

时间复杂度 $O(n \log^2 n)$,空间复杂度 $O(n \log n)$。

```cpp
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef __int128 lll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 100100;
const int logn = 20;

// a, b: fare parameters per node; dep: weighted depth from node 1;
// f: DP over paths with decreasing b; h: final answers (g in the write-up).
ll n, a[maxn], b[maxn], dep[maxn], f[maxn], h[maxn];
// st[0][dfn[u]] = parent of u; st[k][i] keeps the entry with the smallest dfn
// among positions [i, i + 2^k).
int st[logn][maxn], dfn[maxn], tim;
vector<pii> G[maxn];  // tree adjacency: (neighbour, edge weight)
vector<int> T[maxn];  // children in the centroid tree

// Returns whichever of i, j has the smaller dfn.
inline int get(int i, int j) { return dfn[i] < dfn[j] ? i : j; }

// O(1) LCA: the parent of the minimum-dfn node in dfn range (dfn[x], dfn[y]].
inline int qlca(int x, int y) {
    if (x == y) {
        return x;
    }
    x = dfn[x];
    y = dfn[y];
    if (x > y) {
        swap(x, y);
    }
    ++x;
    int k = __lg(y - x + 1);
    return get(st[k][x], st[k][y - (1 << k) + 1]);
}

// Weighted tree distance between x and y.
inline ll qdis(int x, int y) { return dep[x] + dep[y] - dep[qlca(x, y)] * 2; }

// Fills dfn, dep and st[0]. Call with t = 0 for the root: st[0][1] then holds 0,
// a valid index with dfn[0] == 0. (A sentinel of -1 made the sparse-table build
// evaluate get(-1, ...), reading dfn[-1] out of bounds.)
void dfs(int u, int t) {
    dfn[u] = ++tim;
    st[0][tim] = t;
    for (const pii &p : G[u]) {
        ll v = p.fst, d = p.scd;
        if (v == t) {
            continue;
        }
        dep[v] = dep[u] + d;
        dfs(v, u);
    }
}

struct node {
    ll x, y;
    node(ll a = 0, ll b = 0) : x(a), y(b) {}
};
inline node operator+(const node &a, const node &b) { return node(a.x + b.x, a.y + b.y); }
inline node operator-(const node &a, const node &b) { return node(a.x - b.x, a.y - b.y); }
// 2D cross product, widened to __int128 so it cannot overflow.
inline lll operator*(const node &a, const node &b) { return (lll)a.x * b.y - (lll)a.y * b.x; }

int rt, g[maxn], sz[maxn], fa[maxn], p[maxn];
bool vis[maxn];

// Computes subtree sizes inside the current component (of size t) and sets rt to
// the node whose largest remaining piece g[] is smallest, i.e. the centroid.
void dfs2(int u, int fa, int t) {
    sz[u] = 1;
    g[u] = 0;
    for (const pii &p : G[u]) {
        int v = p.fst;
        if (v == fa || vis[v]) {
            continue;
        }
        dfs2(v, u, t);
        sz[u] += sz[v];
        g[u] = max(g[u], sz[v]);
    }
    g[u] = max(g[u], t - sz[u]);
    if (!rt || g[u] < g[rt]) {
        rt = u;
    }
}

// Builds the centroid tree (fa[], T[]) below centroid u.
void dfs3(int u) {
    vis[u] = 1;
    for (const pii &p : G[u]) {
        int v = p.fst;
        if (vis[v]) {
            continue;
        }
        rt = 0;
        dfs2(v, u, sz[v]);
        dfs2(rt, u, sz[v]);  // re-root at the new centroid so sz[] is valid for its neighbours
        fa[rt] = u;
        T[u].pb(rt);
        dfs3(rt);
    }
}

vector<node> S[maxn];  // per centroid: incrementally built lower hull used for f
node c[maxn];
int stk[maxn], top, tot, pt[maxn], K;
ll e[maxn];  // distance from the current centroid

// Collects every node of the current component: its hull point and its distance d to the centroid.
void dfs5(int u, int fa, ll d) {
    c[++tot] = node(-b[u], f[u] + a[u] + b[u] * d);
    e[u] = d;
    pt[++K] = u;
    for (const pii &p : G[u]) {
        ll v = p.fst, k = p.scd;
        if (v == fa || vis[v]) {
            continue;
        }
        dfs5(v, u, d + k);
    }
}

// For centroid u: builds the lower hull of all points in its component, then for
// every node v there relaxes h[v] with min over w of f_w + a_w + (e[w] + e[v]) * b_w.
inline void work(int u) {
    tot = 0;
    K = 0;
    dfs5(u, -1, 0);
    sort(c + 1, c + tot + 1, [&](const node &a, const node &b) {
        return a.x < b.x || (a.x == b.x && a.y < b.y);
    });
    top = 0;
    for (int i = 1; i <= tot; ++i) {
        // Drop the top while it lies on or above the segment to c[i] (keeps the lower hull).
        while (top >= 2 && (c[i] - c[stk[top - 1]]) * (c[stk[top]] - c[stk[top - 1]]) >= 0) {
            --top;
        }
        stk[++top] = i;
    }
    tot = top;
    for (int i = 1; i <= tot; ++i) {
        c[i] = c[stk[i]];
    }
    for (int i = 1; i <= K; ++i) {
        int v = pt[i], l = 2, r = tot, p = 1;
        // The last vertex whose incoming edge has slope <= e[v] minimises y - x * e[v].
        while (l <= r) {
            int mid = (l + r) >> 1;
            if (c[mid].y - c[mid - 1].y <= (c[mid].x - c[mid - 1].x) * e[v]) {
                p = mid;
                l = mid + 1;
            } else {
                r = mid - 1;
            }
        }
        h[v] = min(h[v], c[p].y - c[p].x * e[v]);
    }
}

// Walks the centroid tree top-down, running work() on each centroid's component.
void dfs4(int u) {
    vis[u] = 1;
    work(u);
    for (int v : T[u]) {
        dfs4(v);
    }
}

// Entry point: nodes are 0-indexed in the arguments and 1-indexed internally.
// Returns the shortest distance from node 1 to nodes 2..n (in that order).
vector<ll> travel(vector<ll> _a, vector<int> _b, vector<int> _u, vector<int> _v, vector<int> _w) {
    n = (int)_a.size();
    for (int i = 1; i <= n; ++i) {
        a[i] = _a[i - 1];
        b[i] = _b[i - 1];
    }
    for (int i = 0; i < n - 1; ++i) {
        int u = _u[i] + 1, v = _v[i] + 1, d = _w[i];
        G[u].pb(v, d);
        G[v].pb(u, d);
    }
    dfs(1, 0);  // 0, not -1, as the root's parent: see dfs()
    for (int j = 1; (1 << j) <= n; ++j) {
        for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
            st[j][i] = get(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        }
    }

    // Centroid decomposition.
    dfs2(1, -1, n);
    int x = rt;
    dfs2(x, -1, n);
    dfs3(x);

    // Phase 1: compute f in non-increasing order of b, with node 1 first.
    for (int i = 1; i <= n; ++i) {
        p[i] = i;
    }
    sort(p + 2, p + n + 1, [&](const int &x, const int &y) { return b[x] > b[y]; });
    mems(f, 0x3f);
    for (int i = 1; i <= n; ++i) {
        int u = p[i];
        if (i > 1 && b[u] > b[1]) {
            continue;  // can only be a final destination; f[u] stays +inf
        }
        if (i == 1) {
            f[u] = 0;
        } else {
            for (int v = u; v; v = fa[v]) {
                if (S[v].empty()) {
                    continue;
                }
                int l = 1, r = (int)S[v].size() - 1, p = 0;
                ll k = qdis(u, v);
                while (l <= r) {
                    int mid = (l + r) >> 1;
                    if (S[v][mid].y - S[v][mid - 1].y <= (S[v][mid].x - S[v][mid - 1].x) * k) {
                        p = mid;
                        l = mid + 1;
                    } else {
                        r = mid - 1;
                    }
                }
                f[u] = min(f[u], S[v][p].y - S[v][p].x * k);
            }
        }
        // Insert u into the hull of every centroid ancestor; x = -b[u] is non-decreasing.
        for (int v = u; v; v = fa[v]) {
            node w(-b[u], f[u] + a[u] + b[u] * qdis(u, v));
            while ((int)S[v].size() >= 2 &&
                   (w - S[v][(int)S[v].size() - 2]) * (S[v].back() - S[v][(int)S[v].size() - 2]) >= 0) {
                S[v].pop_back();
            }
            S[v].pb(w);
        }
    }

    // Phase 2: answers g (stored in h) via per-centroid static hulls.
    mems(h, 0x3f);
    mems(vis, 0);
    dfs4(x);
    vector<ll> vc;
    for (int i = 2; i <= n; ++i) {
        vc.pb(h[i]);
    }
    return vc;
}

// int main() {
//     int n;
//     scanf("%d", &n);
//     vector<ll> a(n);
//     vector<int> b(n);
//     for (ll &x : a) {
//         scanf("%lld", &x);
//     }
//     for (int &x : b) {
//         scanf("%d", &x);
//     }
//     vector<int> u(n - 1), v(n - 1), w(n - 1);
//     for (int i = 0; i < n - 1; ++i) {
//         scanf("%d%d%d", &u[i], &v[i], &w[i]);
//     }
//     auto ans = travel(a, b, u, v, w);
//     for (ll x : ans) {
//         printf("%lld\n", x);
//     }
//     return 0;
// }
```