CTSC 2017 网络(树的直径,结论,二分答案,排序,双指针)

· · 题解

CTSC 2017 网络(树的直径,结论,二分答案,排序,双指针)

首先,有个结论是选择的两个端点一定在直径上

于是最后的答案只和一个点到直径某个端点的距离 \text{pre}_i,和它挂着的最长链的长度 \text{val}_i 有关。

假设二分的答案为 x,加的边的长度为 m,对于任意的 j<i,如果满足 \text{pre}_i-\text{pre}_j+\text{val}_i+\text{val}_j>x,那么必须满足存在一个点对 b<a,使得

|\text{pre}_j-\text{pre}_b|+|\text{pre}_i-\text{pre}_a|+\text{val}_i+\text{val}_j+m\leq x

把绝对值拆开,也就是要找到一个 b<a,满足

\begin{aligned} x-m+\text{pre}_a+\text{pre}_b\geq \text{val}_i+\text{pre}_i+\text{val}_j+\text{pre}_j=A\\ x-m-\text{pre}_a+\text{pre}_b\geq \text{val}_i-\text{pre}_i+\text{val}_j+\text{pre}_j=B\\ x-m+\text{pre}_a-\text{pre}_b\geq \text{val}_i+\text{pre}_i+\text{val}_j-\text{pre}_j=C\\ x-m-\text{pre}_a-\text{pre}_b\geq \text{val}_i-\text{pre}_i+\text{val}_j-\text{pre}_j=D \end{aligned}

可以先尝试找到最大的 A,B,C,D 来简化问题。

考虑对每个 i 计算答案,我们需要满足那个限制,于是可以考虑开两个数组,一个以 \text{val}_i+\text{pre}_i 为关键字从小到大排序,一个以 \text{val}_j-\text{pre}_j 为关键字从大到小排序,双指针一下即可得到 A,B,C,D 的最大值。

这样做可能有个问题:排了序之后就不一定满足 j<i 的限制了,但其实是不影响的因为如果存在 j>i 满足那个限制,那么 j<i 时也一定满足那个限制。

找到 A,B,C,D 的最大值之后可以考虑对于每个 a 计算 b 的合法区间,根据单调性维护 4 个指针,然后求一下区间交即可判断。

时间复杂度 \mathcal{O}(n\log V)

#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int maxn = 1e5 + 5;
int n, cnt, fa[maxn], p[maxn], p1[maxn], p2[maxn];
ll m, dis[maxn], mx[maxn], val[maxn], pre[maxn];
vector<pair<int, int>> e[maxn];
bool cmp1(int x, int y) {
    return val[x] + pre[x] < val[y] + pre[y];
}
bool cmp2(int x, int y) {
    return val[x] - pre[x] > val[y] - pre[y];
}
void dfs(int u, int from) {
    fa[u] = from;
    mx[u] = dis[u];
    for(auto v : e[u]) {
        if(v.fi == from) {
            continue;
        }
        dis[v.fi] = dis[u] + v.se;
        dfs(v.fi, u);
        mx[u] = max(mx[u], mx[v.fi]);
    }
}
bool check(ll x) {
    ll A = -inf, B = -inf, C = -inf, D = -inf, mx1 = -inf, mx2 = -inf;
    for(int i = 1, j = 0; i <= cnt; i++) {
        while(j + 1 <= cnt && val[p1[i]] + pre[p1[i]] + val[p2[j + 1]] - pre[p2[j + 1]] > x) {
            j++;
            mx1 = max(mx1, val[p2[j]] + pre[p2[j]]);
            mx2 = max(mx2, val[p2[j]] - pre[p2[j]]);
        }
        A = max(A, val[p1[i]] + pre[p1[i]] + mx1);
        B = max(B, val[p1[i]] - pre[p1[i]] + mx1);
        C = max(C, val[p1[i]] + pre[p1[i]] + mx2);
        D = max(D, val[p1[i]] - pre[p1[i]] + mx2);
        if(i == cnt && j == 0) {
            return 1;
        }
    }
    A += m - x, B += m - x, C += m - x, D += m - x;
    int a = cnt + 1, b = 1, c = 0, d = cnt;
    for(int i = 1; i <= cnt; i++) {
        while(a > 1 && pre[i] + pre[a - 1] >= A) {
            a--;
        }
        while(b <= cnt && -pre[i] + pre[b] < B) {
            b++;
        }
        while(c < cnt && pre[i] - pre[c + 1] >= C) {
            c++;
        }
        while(d >= 1 && -pre[i] - pre[d] < D) {
            d--;
        }
        /*
        [a, i)
        [b, i)
        [1, c]
        [1, d]
        */
        if(max(a, b) <= min(min(c, d), i - 1)) {
            return 1;
        }
    }
    return 0;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    while(cin >> n >> m) {
        if(n == 0 && m == 0) {
            break;
        }
        for(int i = 1; i <= n; i++) {
            e[i].clear();
        }
        for(int i = 1, u, v, w; i < n; i++) {
            cin >> u >> v >> w;
            e[u].eb(v, w);
            e[v].eb(u, w);
        }
        dis[1] = 0;
        dfs(1, 0);
        int S = 1, T = 1;
        for(int i = 1; i <= n; i++) {
            if(dis[i] > dis[S]) {
                S = i;
            }
        }
        dis[S] = 0;
        dfs(S, 0);
        for(int i = 1; i <= n; i++) {
            if(dis[i] > dis[T]) {
                T = i;
            }
        }
        cnt = 0;
        int u = T;
        while(u != S) {
            p[++cnt] = u;
            u = fa[u];
        }
        p[++cnt] = S;
        reverse(p + 1, p + 1 + cnt);
        for(int i = 1; i <= cnt; i++) {
            p1[i] = p2[i] = i;
            pre[i] = dis[p[i]];
            val[i] = 0;
            for(auto v : e[p[i]]) {
                if(v.fi != p[i - 1] && v.fi != p[i + 1]) {
                    val[i] = max(val[i], mx[v.fi] - dis[p[i]]);
                }
            }
        }
        sort(p1 + 1, p1 + 1 + cnt, cmp1);
        sort(p2 + 1, p2 + 1 + cnt, cmp2);
        ll l = 0, r = inf, ans = 0;
        while(l <= r) {
            ll mid = (l + r) >> 1;
            if(check(mid)) {
                ans = mid;
                r = mid - 1;
            }
            else {
                l = mid + 1;
            }
        }
        cout << ans << '\n';
    }
    return 0;
}