题解:P2680 [NOIP2015 提高组] 运输计划
题目思路
前置知识:最近公共祖先,树上差分,二分。
首先我们看到题目中的图保证联通且有
根据题意,找所有运输航道至少要多长时间,即要使花费时间最长的那个运输航线花费的时间最小,就是最大答案的最小值,我们就可以考虑对答案进行二分,然后进行贪心的 check。
我们怎么对当前的答案进行 check 呢?首先我们可以找出所有运输航线的花费时间,对于小于等于当前答案的运输航线,我们可以不进行考虑。因为此时就算开一条虫洞,这条运输航线也一定小于等于当前答案。
对于大于当前答案的运输航线,我们一定要在这些运输航线的公共边上开虫洞,如果我们在其他边上开虫洞,那么一定会存在运输航线花费时间还大于当前答案,不满足条件。现在问题转化为怎么快速找到公共边。这里注意,公共边可能存在多条。
大于当前答案的运输航线有
我们在处理边权的时候也将边权下放到点上,同理从叶子向上累加,对于
判断当前公共边作为虫洞是否可行,就找这
但是这样的话我们在 check 中的循环还计算了 lca,时间复杂度为
假设树为题目样例,当前二分答案为
Code
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int N = 3e5 + 10;
int n, q, cnt, head[N];
int a[N], b[N], disab[N], lcaab[N];
int dis[N], dep[N], pre[N][21];
int run[N];
struct node {
int to, next, w;
} e[N << 1];
void addi(int u, int v, int w) {
e[++ cnt] = {v, head[u], w};
head[u] = cnt;
}
void add(int u, int v, int w) {
addi(u, v, w); addi(v, u, w);
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
pre[u][0] = fa;
for(int i = 1; i <= 20; i ++) {
pre[u][i] = pre[pre[u][i - 1]][i - 1];
}
for(int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if(v == fa) continue;
dis[v] = e[i].w + dis[u];
dfs(v, u);
}
}
int lca(int u, int v) {
if(u == v) return u;
if(dep[u] < dep[v]) swap(u, v);
for(int i = 20; i >= 0; i --) {
if(dep[v] <= dep[u] - (1 << i)) u = pre[u][i];
}
if(u == v) return u;
for(int i = 20; i >= 0; i --) {
if(pre[u][i] == pre[v][i]) continue;
u = pre[u][i], v = pre[v][i];
}
return pre[u][0];
}
void dfs2(int u, int fa) {
for(int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if(v == fa) continue;
dfs2(v, u);
run[u] += run[v];
}
}
bool check(int x) {
for(int i = 1; i <= n; i ++) run[i] = 0;
int tot = 0, maxdis = 0;
for(int i = 1; i <= q; i ++) {
maxdis = max(maxdis, disab[i]);
if(disab[i] > x) {
run[a[i]] ++, run[b[i]] ++;
run[lcaab[i]] -= 2;
tot ++;
}
}
if(tot == 0) return 1;
dfs2(1, 0);
for(int i = 1; i <= n; i ++) {
int w = dis[i] - dis[pre[i][0]];
if(run[i] == tot && maxdis - w <= x) {
return 1;
}
}
return 0;
}
inline int rd(){
int w = 1, x = 0;
char c = getchar();
while(c < 48 || c > 57) {
if(c == 45) w *= -1;
c = getchar();
}
while(c >= 48 && c <= 57) {
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return x * w;
}
int main() {
n = rd(), q = rd();
for(int i = 1; i < n; i ++) {
int u = rd(), v = rd(), w = rd();
add(u, v, w);
}
dfs(1, 0);
for(int i = 1; i <= q; i ++) {
a[i] = rd(), b[i] = rd();
lcaab[i] = lca(a[i], b[i]);
disab[i] = dis[a[i]] + dis[b[i]] - 2 * dis[lcaab[i]];
}
int l = 0, r = 1e9, mid, ans = 0;
while(l <= r) {
mid = (l + r) >> 1;
if(check(mid)) r = mid - 1, ans = mid;
else l = mid + 1;
}
cout << ans << "\n";
return 0;
}