P11343 [KTSC 2023 R1] 出租车旅行
EuphoricStar
·
·
题解
给定一棵 $n$ 个点的树,边有边权。有一个 $n$ 个点的有向完全图,$i \to j$ 的边权为 $a_i + \operatorname{dis}(i, j) \times b_i$,其中 $\operatorname{dis}(i, j)$ 为树上距离。对于所有 $2 \le i \le n$,求 $1 \rightsquigarrow i$ 的最短路。
----
**性质**:最短路除了终点的其他点 $b_i$ 递减。
考虑 DP,设 $f_u$ 为 $1 \rightsquigarrow u$ 且路径所有点 $b_u$ 递减的最短路长度,有:
$$
f_u = \min\limits_{b_v > b_u} f_v + a_v + \operatorname{dis}(u, v) \times b_v
$$
再设 $g_u$ 为 $1 \rightsquigarrow u$ 的最短路长度,有:
$$
g_u = \min\limits_{v} f_v + a_v + \operatorname{dis}(u, v) \times b_v
$$
看到树上距离考虑点分治。
点分树上每个点 $u$ 维护子树内 $(-b_v, f_v + a_v + \operatorname{dis}(u, v) \times b_v)$ 的凸包。由于 $b_v$ 有序所以凸包可以用支持 pop_back 的 vector 维护。
算 $f_u$ 就枚举 $u$ 在点分树上的祖先 $v$(含 $u$ 自身),相当于用斜率为 $\operatorname{dis}(u, v)$ 的直线去切 $v$ 上维护的下凸壳,求最小截距 $y - \operatorname{dis}(u, v) \cdot x$。在凸壳上二分即可。
算 $g_u$ 可以用同样的方法。
时间复杂度 $O(n \log^2 n)$,空间复杂度 $O(n \log n)$。
```cpp
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef __int128 lll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
// Capacity of every per-vertex array; travel() stores vertices at indices 1..n.
const int maxn = 100100;
// Number of rows in the sparse table st.
const int logn = 20;
// n: vertex count. a, b: per-vertex inputs copied in by travel().
// dep: weighted distance from the DFS root, filled by dfs().
// f: DP value computed in travel(); h: final answers (travel() returns h[2..n]).
ll n, a[maxn], b[maxn], dep[maxn], f[maxn], h[maxn];
// st: sparse table over DFS order (row 0 holds the DFS parent of the vertex at
// each position); dfn: DFS visiting index of each vertex; tim: DFS counter.
int st[logn][maxn], dfn[maxn], tim;
// Adjacency list of the input tree: (neighbour, edge weight).
vector<pii> G[maxn];
// Children lists of the centroid tree, appended to by dfs3().
vector<int> T[maxn];
// Of two vertices, pick the one with the strictly smaller DFS index;
// when the indices are equal the second argument is returned.
inline int get(int i, int j) {
    if (dfn[i] < dfn[j]) {
        return i;
    }
    return j;
}
inline int qlca(int x, int y) {
if (x == y) {
return x;
}
x = dfn[x];
y = dfn[y];
if (x > y) {
swap(x, y);
}
++x;
int k = __lg(y - x + 1);
return get(st[k][x], st[k][y - (1 << k) + 1]);
}
// Weighted tree distance between x and y, from root depths and their LCA.
inline ll qdis(int x, int y) {
    const ll lcaDepth = dep[qlca(x, y)];
    return dep[x] + dep[y] - lcaDepth * 2;
}
void dfs(int u, int t) {
dfn[u] = ++tim;
st[0][tim] = t;
for (pii p : G[u]) {
ll v = p.fst, d = p.scd;
if (v == t) {
continue;
}
dep[v] = dep[u] + d;
dfs(v, u);
}
}
// A 2-D point / vector with 64-bit coordinates; both default to zero.
struct node {
    ll x;
    ll y;
    node(ll a = 0, ll b = 0) {
        x = a;
        y = b;
    }
};
// Component-wise sum of two vectors.
inline node operator + (const node &a, const node &b) {
    node sum = a;
    sum.x += b.x;
    sum.y += b.y;
    return sum;
}
// Component-wise difference a - b.
inline node operator - (const node &a, const node &b) {
    node diff = a;
    diff.x -= b.x;
    diff.y -= b.y;
    return diff;
}
// 2-D cross product a x b, evaluated in 128-bit arithmetic so the products of
// two 64-bit coordinates cannot overflow.
inline lll operator * (const node &a, const node &b) {
    const lll forward = static_cast<lll>(a.x) * b.y;
    const lll backward = static_cast<lll>(a.y) * b.x;
    return forward - backward;
}
// rt: best centroid candidate found so far by dfs2() (0 means "none yet").
// g[u]: size of the largest piece left if u is removed from the current component.
// sz[u]: subtree size from the latest dfs2() pass.
// fa[u]: parent of u in the centroid tree (stays 0 for the centroid-tree root).
// p: order in which travel() processes the vertices.
int rt, g[maxn], sz[maxn], fa[maxn], p[maxn];
// vis[u]: u has already been taken as a centroid and is treated as removed.
bool vis[maxn];
void dfs2(int u, int fa, int t) {
sz[u] = 1;
g[u] = 0;
for (pii p : G[u]) {
int v = p.fst;
if (v == fa || vis[v]) {
continue;
}
dfs2(v, u, t);
sz[u] += sz[v];
g[u] = max(g[u], sz[v]);
}
g[u] = max(g[u], t - sz[u]);
if (!rt || g[u] < g[rt]) {
rt = u;
}
}
void dfs3(int u) {
vis[u] = 1;
for (pii p : G[u]) {
int v = p.fst;
if (vis[v]) {
continue;
}
rt = 0;
dfs2(v, u, sz[v]);
dfs2(rt, u, sz[v]);
fa[rt] = u;
T[u].pb(rt);
dfs3(rt);
}
}
// S[v]: convex hull kept at centroid-tree vertex v, grown incrementally in travel().
vector<node> S[maxn];
// c[1..tot]: scratch points gathered by dfs5() and reduced to a hull in work().
node c[maxn];
// stk/top: stack of indices into c used while building the hull in work().
// tot: number of valid entries in c. pt[1..K]: vertices gathered by dfs5().
int stk[maxn], top, tot, pt[maxn], K;
// e[u]: distance from the centroid currently handled by work() to u.
ll e[maxn];
void dfs5(int u, int fa, ll d) {
c[++tot] = node(-b[u], f[u] + a[u] + b[u] * d);
e[u] = d;
pt[++K] = u;
for (pii p : G[u]) {
ll v = p.fst, k = p.scd;
if (v == fa || vis[v]) {
continue;
}
dfs5(v, u, d + k);
}
}
inline void work(int u) {
tot = 0;
K = 0;
dfs5(u, -1, 0);
sort(c + 1, c + tot + 1, [&](const node &a, const node &b) {
return a.x < b.x || (a.x == b.x && a.y < b.y);
});
top = 0;
for (int i = 1; i <= tot; ++i) {
while (top >= 2 && (c[i] - c[stk[top - 1]]) * (c[stk[top]] - c[stk[top - 1]]) >= 0) {
--top;
}
stk[++top] = i;
}
tot = top;
for (int i = 1; i <= tot; ++i) {
c[i] = c[stk[i]];
}
for (int i = 1; i <= K; ++i) {
int v = pt[i], l = 2, r = tot, p = 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (c[mid].y - c[mid - 1].y <= (c[mid].x - c[mid - 1].x) * e[v]) {
p = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
h[v] = min(h[v], c[p].y - c[p].x * e[v]);
}
}
void dfs4(int u) {
vis[u] = 1;
work(u);
for (int v : T[u]) {
dfs4(v);
}
}
// Solves the problem for one tree and returns the shortest-path cost from
// vertex 1 to every other vertex.
//
// Parameters (all 0-indexed on input, stored 1-indexed internally):
//   _a, _b : per-vertex costs; moving from i to j costs a_i + dis(i, j) * b_i.
//   _u, _v : endpoints of the n - 1 tree edges.
//   _w     : weights of those edges.
// Returns a vector of n - 1 values: the answers for vertices 2..n (1-indexed),
// in that order.
//
// The function works on file-level arrays and appends to G, T and S without
// clearing them, so it is meant to be called once per process.
vector<ll> travel(vector<ll> _a, vector<int> _b, vector<int> _u, vector<int> _v, vector<int> _w) {
    n = (int)_a.size();
    // With zero or one vertex there is nothing to report; returning early also
    // keeps the sort below from receiving an invalid (first > last) range.
    if (n <= 1) {
        return {};
    }
    for (int i = 1; i <= n; ++i) {
        a[i] = _a[i - 1];
        b[i] = _b[i - 1];
    }
    for (int i = 0; i < n - 1; ++i) {
        int u = _u[i] + 1, v = _v[i] + 1, d = _w[i];
        G[u].pb(v, d);
        G[v].pb(u, d);
    }

    // Use 0 (not -1) as the root's parent: the value is stored in st[0][1] and
    // later passed to get(), which indexes dfn[] with it. dfn[0] is a valid
    // slot, whereas dfn[-1] would be an out-of-bounds read.
    dfs(1, 0);
    for (int j = 1; (1 << j) <= n; ++j) {
        for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
            st[j][i] = get(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        }
    }

    // Centroid decomposition: find the top centroid, recompute sizes rooted at
    // it, then build the centroid tree.
    dfs2(1, -1, n);
    int root = rt;
    dfs2(root, -1, n);
    dfs3(root);

    // Vertex 1 is processed first; the rest in non-increasing order of b.
    for (int i = 1; i <= n; ++i) {
        p[i] = i;
    }
    sort(p + 2, p + n + 1, [&](const int &lhs, const int &rhs) {
        return b[lhs] > b[rhs];
    });

    mems(f, 0x3f);
    for (int i = 1; i <= n; ++i) {
        int u = p[i];
        // A vertex with larger b than the start can never be an intermediate
        // stop on a b-decreasing path from vertex 1; leave f[u] at "infinity".
        if (i > 1 && b[u] > b[1]) {
            continue;
        }
        if (i == 1) {
            f[u] = 0;
        } else {
            // Query the hull of every centroid-tree ancestor (u included).
            for (int v = u; v; v = fa[v]) {
                if (S[v].empty()) {
                    continue;
                }
                int l = 1, r = (int)S[v].size() - 1, best = 0;
                ll k = qdis(u, v);
                // Last hull index whose incoming edge has slope <= k.
                while (l <= r) {
                    int mid = (l + r) >> 1;
                    if (S[v][mid].y - S[v][mid - 1].y <= (S[v][mid].x - S[v][mid - 1].x) * k) {
                        best = mid;
                        l = mid + 1;
                    } else {
                        r = mid - 1;
                    }
                }
                f[u] = min(f[u], S[v][best].y - S[v][best].x * k);
            }
        }
        // Insert u's point into every ancestor hull. Points arrive with
        // non-decreasing x = -b[u], so popping from the back is sufficient.
        for (int v = u; v; v = fa[v]) {
            node w(-b[u], f[u] + a[u] + b[u] * qdis(u, v));
            while ((int)S[v].size() >= 2 && (w - S[v][(int)S[v].size() - 2]) * (S[v].back() - S[v][(int)S[v].size() - 2]) >= 0) {
                S[v].pop_back();
            }
            S[v].pb(w);
        }
    }

    // Final answers: one more sweep over the centroid tree with all f known.
    mems(h, 0x3f);
    mems(vis, 0);
    dfs4(root);
    vector<ll> vc;
    for (int i = 2; i <= n; ++i) {
        vc.pb(h[i]);
    }
    return vc;
}
// int main() {
// int n;
// scanf("%d", &n);
// vector<ll> a(n);
// vector<int> b(n);
// for (ll &x : a) {
// scanf("%lld", &x);
// }
// for (int &x : b) {
// scanf("%d", &x);
// }
// vector<int> u(n - 1), v(n - 1), w(n - 1);
// for (int i = 0; i < n - 1; ++i) {
// scanf("%d%d%d", &u[i], &v[i], &w[i]);
// }
// auto ans = travel(a, b, u, v, w);
// for (ll x : ans) {
// printf("%lld\n", x);
// }
// return 0;
// }
```