题解:P6670 [清华集训 2016] 汽水

· · 题解

题意

给定一棵 n 个节点的树和一个常数 k,边有边权;求一条路径,使得路径上边权的平均值与 k 的差的绝对值最小,输出这个绝对值下取整后的结果。

解法

令边 e 的边权为 w_e,题目相当于对于所有路径 \mathrm {path},求解

\left|\frac{\sum_{e\in\mathrm{path}}w_e}{|\mathrm {path}|}-k\right|

的最小值。把 k 移进分数里,变为求解

\left|\frac{\sum_{e\in\mathrm{path}}w_e-k}{|\mathrm{path}|}\right|

的最小值,相当于给每个边的边权减去 k 后求路径平均边权的绝对值的最小值。注意到 |\mathrm{path}| 可以表示为 \sum_{e\in\mathrm{path}}1,因此这个式子可以用分数规划的方法求解。考虑二分答案为 w,转为判断是否存在一条路径 \mathrm{path},使得

\left|\frac{\sum_{e\in\mathrm{path}}w_e}{\sum_{e\in\mathrm{path}}1}\right|\le w\\\Longrightarrow~~-w\le\frac{\sum_{e\in\mathrm{path}}w_e}{\sum_{e\in\mathrm{path}}1}\le w\\\Longrightarrow~~\sum_{e\in\mathrm{path}}w_e-w\le 0\le\sum_{e\in\mathrm{path}}w_e+w

考虑点分治,假设现在的重心是 u,对于两个不在 u 同一个孩子子树里的点 v_1,v_2,设它们到根的路径中,w_e+w 的和为 \mathrm{dis}_{v_1}\mathrm{dis}_{v_2}w_e-w 的和为 \mathrm{sid}_{v_1}\mathrm{sid}_{v_2},则 v_1\to\cdots\to u\to\cdots\to v_2 这样一条路径合法当且仅当:

\left\{\begin{matrix}\mathrm{dis}_{v_1}+\mathrm{dis}_{v_2}\ge 0\\\mathrm{sid}_{v_1}+\mathrm{sid}_{v_2}\le 0\end{matrix}\right.

不妨把所有 u 子树里的点 v 按照 \mathrm{dis}_{v} 从小到大排序;那么再枚举 v_1,满足 \mathrm{dis}_{v_1}+\mathrm{dis}_{v_2}\ge 0v_2 的范围总是一段后缀,并且这段后缀随着 u 的增大长度单调不降。现在我们得到了满足 \mathrm{dis}_{v_1}+\mathrm{dis}_{v_2}\ge 0v_2 的范围 [p,q],只需要找到一个 v_2\in[p,q] 使得 \mathrm{sid}_{v_1}+\mathrm{sid}_{v_2}\le 0,并且 v_1,v_2 不在同一子树,找到一个就说明当前的答案 w 是合法的。我们贪心地取出 [p,q]\mathrm{sid}_{v_2} 最小的且和 v_1 不在同一子树的 v_2 进行判断即可。

总结一下,先二分 w,再做点分治;每轮点分治把该层所有的点拿出来按照 \mathrm{dis} 从小到大排序,用双指针维护满足第一个条件的合法的点的选取范围,预处理一个后缀 \mathrm{sid} 最小值和次小值(且次小值所在子树和最小值不同)进行判断。时间复杂度 O(n\log^2n\log V)

代码

这里因为懒直接用的实数二分,少了一些整数时边界的判断。

#include <bits/stdc++.h>
bool MemoryST; using namespace std;
#define ll long long
#define mk make_pair
#define open(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define lowbit(x) ((x) & (-(x)))
#define lson l, mid, rt << 1
#define rson mid + 1, r, rt << 1 | 1
#define BCNT __builtin_popcount
#define cost_time (1e3 * clock() / CLOCKS_PER_SEC) << "ms"
#define cost_space (abs(&MemoryST - &MemoryED) / 1024.0 / 1024.0) << "MB"
const int inf = 0x3f3f3f3f; 
const ll linf = 1e18; 
mt19937 rnd(random_device{}());
template<typename T> void chkmax(T& x, T y) { x = max(x, y); }
template<typename T> void chkmin(T& x, T y) { x = min(x, y); }
template<typename T> T abs(T x) { return (x < 0) ? -x : x; }
const double eps = 1e-8;
bool MemoryED; int main() { // open(data);
    int n; ll K; cin >> n >> K;
    vector<vector<pair<int, ll> > >mp(n);
    auto addEdge = [&](int u, int v, ll w) -> void {
        mp[u].emplace_back(v, w);
    };
    for (int i = 1; i < n; i ++) {
        int u, v; ll w; cin >> u >> v >> w; u --, v --;
        w -= K, addEdge(u, v, w), addEdge(v, u, w);
    } double ans = 0;
    vector<int> siz(n), mxsiz(n); vector<bool> used(n, 0);
    auto getsiz = [&](auto self, int u, int fa) -> void {
        siz[u] = 1;
        for (auto [v, w] : mp[u])
            if (!used[v] && v != fa) self(self, v, u), siz[u] += siz[v];
    }; auto findrt = [&](auto self, int u, int fa, int all, int &rt) -> void {
        mxsiz[u] = all - siz[u];
        for (auto [v, w] : mp[u])
            if (!used[v] && v != fa) self(self, v, u, all, rt), chkmax(mxsiz[u], siz[v]);
        if (rt == -1 || mxsiz[rt] > mxsiz[u]) rt = u;
    }; 
    int _ = 1; for (double l = 0, r = 1e13; _ <= 50 && r - l >= eps; _ ++) {
        double k = (l + r) / 2;
        bool ok = 0; for (int i = 0; i < n; i ++) used[i] = 0;
        vector<pair<pair<double, double>, int> > info;
        auto getdis = [&](auto self, int u, int fa, int top, double cur_dis, double cur_sid) -> void {
            info.emplace_back(mk(cur_dis, cur_sid), top);
            for (auto [v, w] : mp[u])
                if (!used[v] && v != fa) self(self, v, u, top, cur_dis + 1.0 * w + k, cur_sid + 1.0 * w - k);
        }; 
        auto dfz = [&](auto self, int u) -> void {
            if (ok) return ;
            used[u] = 1, getsiz(getsiz, u, -1);
            for (auto [v, w] : mp[u])
                if (!used[v]) getdis(getdis, v, u, v, 1.0 * w + k, 1.0 * w - k);
            info.emplace_back(mk(0, 0), u); 
            sort(info.begin(), info.end()); int len = (int)info.size();
            vector<pair<int, double> > mn(len); vector<pair<int, double> > sec_mn(len);
            mn[len - 1] = mk(info[len - 1].second, info[len - 1].first.second), sec_mn[len - 1] = mk(-1, linf);
            for (int i = len - 2; ~i; i --) {
                mn[i] = mn[i + 1], sec_mn[i] = sec_mn[i + 1];
                if (mn[i].first == info[i].second) {
                    if (info[i].first.second < mn[i].second)
                        mn[i].second = info[i].first.second;
                } else if (info[i].first.second < mn[i].second)
                    sec_mn[i] = mn[i], mn[i] = mk(info[i].second, info[i].first.second);
                else if (info[i].first.second < sec_mn[i].second)
                    sec_mn[i] = mk(info[i].second, info[i].first.second);
            } for (int L = 0, R = len; !ok && L < len; L ++) {
                for (; L < R - 1 && info[L].first.first + info[R - 1].first.first >= -eps; R --);
                if (R <= L) R = L + 1;
                if (R == len) continue;
                if (mn[R].first == info[L].second) {
                    if (sec_mn[R].second + info[L].first.second <= eps) ok = 1;
                } else if (info[L].first.second + mn[R].second <= eps) ok = 1;
            } while (!info.empty()) info.pop_back();
            for (auto [v, w] : mp[u])
                if (!used[v]) {
                    int rt = -1; findrt(findrt, v, u, siz[v], rt);
                    self(self, rt);
                }
        }; int st = -1; findrt(findrt, 0, -1, n, st), dfz(dfz, st);
        if (ok) ans = k, r = k;
        else l = k;
    } cout << (ll)floor(ans) << '\n'; 
    return 0;
}