题解:P12491 [集训队互测 2024] 串联

· · 题解

一道比较好的点分治板子题。

Solution

看到路径问题,很容易想到点分治。

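套路地,考虑在一个分治中心 $x$ 处,先对每个点 $i$ 求出 $i \to x$ 路径上的 $A_i = \min\{a_j\}$(包含 $a_x$),以及 $B_i = \sum_{j \neq x} b_j$(不含 $b_x$,这样拼接两条路径时 $b_x$ 只在最后计入一次)。然后我们将所有点按 $A_i$ 从大到小排序,这样可以解开 $\min$ 的限制:对于当前点 $i$,排在它前面的点 $j$ 都满足 $A_j \ge A_i$,所以拼成的路径的最小值恰好是 $A_i$。于是对于点 $i$,路径合法当且仅当 $A_i \cdot (B_i + B_j + b_x) \ge V$,即需要在前面的点中求出满足 $B_j \ge \lceil\frac{V}{A_i}\rceil - B_i - b_x$ 的最小的 $B_j$。将 $B$ 离散化后,在树状数组上维护后缀 $\min$ 即可。为了保证两个端点来自不同子树,维护最小值和次小值,并要求两者所在子树不同。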

还有一点要注意:需要特判单个点本身就满足 $a_i \times b_i \ge V$ 的情况,此时 $b_i$ 也可以作为答案。

时间复杂度 $O(n\log^2 n)$。

:::success[AC Code]

写得有点丑,但十分易懂。

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros.
// NOTE: `x` and `y` are object-like macros, so they rewrite EVERY token spelled
// x / y in the rest of the file, not only member accesses: p.x means p.first,
// but a parameter declared as `int x` is really named `first`. Naming a variable
// first/second can therefore collide with one spelled x/y.
#define x first
#define y second
// mp(a, b) is make_pair(a, b).
#define mp(Tx, Ty) make_pair(Tx, Ty)
// Inclusive ascending / descending counted loops (Dec is not referenced in the code shown).
#define For(Ti, Ta, Tb) for(auto Ti = (Ta); Ti <= (Tb); Ti++)
#define Dec(Ti, Ta, Tb) for(auto Ti = (Ta); Ti >= (Tb); Ti--)
// Debug print to stderr (not referenced in the code shown).
#define debug(...) fprintf(stderr, __VA_ARGS__)
// range(c) expands to the iterator pair begin(c), end(c) for algorithms.
#define range(Tx) begin(Tx),end(Tx)
// N: array bound for vertex-indexed arrays (presumably n <= 2e5 from the constraints — confirm).
// INF: "empty" sentinel value stored in the Fenwick trees and returned by sum() when nothing qualifies.
const int N = 2e5 + 5, INF = 1e9;
// Not referenced in the code shown.
const long long INFl = 1e18;
// Number of vertices.
int n;
// Threshold: a path qualifies when min(a) * sum(b) >= V.
long long V;
// Adjacency lists as linked edge lists: h[v] = first edge id of v (-1 = none),
// e[id] = target vertex, ne[id] = next edge id; idx = next free edge id.
int h[N], e[N * 2], ne[N * 2], idx;
// Per-vertex values, 1-indexed.
long long a[N], b[N];
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// Smallest total b over all qualifying paths found so far; stays 1e18 if none is found.
long long ans = 1e18;
// vis[v] is set once v has been used as a centroid; traversals stop at such vertices.
bool vis[N];
// Component/subtree sizes filled in by get_sz / get_wc.
int sz[N];
// Computes the size of the not-yet-removed component hanging below x (coming
// from parent fa), stores it in sz[x], and returns it. Vertices already used
// as centroids count as 0 and are not descended into.
int get_sz(int x, int fa) {
    if (vis[x]) return 0;
    int total = 1;
    for (int eid = h[x]; eid != -1; eid = ne[eid]) {
        const int to = e[eid];
        if (to != fa) total += get_sz(to, x);
    }
    return sz[x] = total;
}
// Walks the component of total size tot below x and returns the size of x's
// part. In post-order, every vertex whose removal leaves no piece larger than
// tot / 2 is written to wc, so wc ends up holding the last such vertex visited.
int get_wc(int x, int fa, int tot, int &wc) {
    if (vis[x]) return 0;
    int here = 1;
    int heaviest = 0;  // largest child piece below x
    for (int eid = h[x]; eid != -1; eid = ne[eid]) {
        const int to = e[eid];
        if (to == fa) continue;
        const int part = get_wc(to, x, tot, wc);
        heaviest = max(heaviest, part);
        here += part;
    }
    sz[x] = here;
    const int above = tot - here;  // the piece on the parent's side
    if (max(heaviest, above) <= tot / 2) wc = x;
    return here;
}
// Entries ((A, B), vertex) collected by get_dist for the current centroid:
// A = min a on the path to the centroid, B = sum of b on that path excluding the centroid.
// work() later overwrites B with its rank after coordinate compression.
vector<pair<pair<long long, long long>, int> > q;
// Id[v] = which branch of the current centroid v lies in (the centroid's child on
// that branch); the centroid itself gets its own vertex number as id.
int Id[N];
void get_dist(int x, int fa, long long A, long long B, int F) {
    if (vis[x]) return;
    q.push_back(mp(mp(A, B), x));
    Id[x] = F;
    for (int i = h[x]; ~i ;i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        get_dist(j, x, min(A, a[j]), B + b[j], (fa == 0 ? j : F));
    }
}
// Sort predicate: orders entries by path minimum A in decreasing order.
bool cmp(pair<pair<long long, long long>, int> x, pair<pair<long long, long long>, int> y) {
    const long long lhs_min = x.first.first;
    const long long rhs_min = y.first.first;
    return rhs_min < lhs_min;
}
// Suffix-min Fenwick tree over compressed B ranks. Each node stores (value, branch id):
// tr[i] is the smallest entry in node i's range, tr1[i] the smallest entry whose branch
// id differs from tr[i]'s. (INF, 0) means "empty".
pair<int, int> tr[N], tr1[N];
// Returns the lowest set bit of x (e.g. 12 -> 4); 0 for x == 0.
int lowbit(int x) {
    const int without_lowest = x & (x - 1);  // clears the lowest set bit
    return x ^ without_lowest;
}
// Inserts candidate cand = (value, branch id) into (best, runner), where best is
// the overall minimum and runner is the minimum among entries whose branch id
// differs from best's.
void update(pair<int, int> &best, pair<int, int> &runner, pair<int, int> cand) {
    if (cand.first >= best.first) {
        // Does not beat the leader; it can only improve the runner-up,
        // and only if it comes from a different branch than the leader.
        if (cand.second != best.second && cand.first < runner.first) runner = cand;
        return;
    }
    // New leader. The old leader becomes the runner-up only when it is from a
    // different branch; otherwise the current runner-up is still valid.
    if (cand.second != best.second) runner = best;
    best = cand;
}
void add(int x, int y, int z) {
    for (int i = x; i; i -= lowbit(i)) update(tr[i], tr1[i], mp(y, z));
}
int sum(int x, int y) {
    pair<int, int> maxn = mp(INF, 0), maxnn = mp(INF, 0);
    for (int i = x; i <= n; i += lowbit(i)) update(maxn, maxnn, tr[i]), update(maxn, maxnn, tr1[i]);
    if (maxn.y == y) return maxnn.x;
    return maxn.x;
}
void del(int x) {
    for (int i = x; i; i -= lowbit(i)) tr[i] = tr1[i] = mp(INF, 0);
}
// bb[1..k]: sorted distinct B values of the current centroid, used for coordinate compression.
long long bb[N];
int k;
// Ceiling of p / q for q > 0. C++ integer division truncates toward zero, so
// only a positive remainder needs rounding up.
static long long ceil_div(long long p, long long q) {
    return p / q + (p % q > 0 ? 1 : 0);
}
// Centroid-decomposition step for the component containing x: re-roots x at the
// component's centroid, evaluates every pair of endpoints from different
// branches (the centroid itself counts as its own branch, so centroid-to-vertex
// paths are included), then recurses into the remaining components.
void work(int x) {
    if (vis[x]) return;
    q.clear();
    get_wc(x, 0, get_sz(x, 0), x);  // x now refers to the centroid
    get_dist(x, 0, a[x], 0, x);
    // Coordinate-compress B; afterwards each entry's B field holds its rank in bb[1..k].
    k = 0;
    for (const auto &i : q) bb[++k] = i.x.y;
    sort(bb + 1, bb + k + 1);
    k = unique(bb + 1, bb + k + 1) - bb - 1;
    for (auto &i : q) i.x.y = lower_bound(bb + 1, bb + k + 1, i.x.y) - bb;
    // Decreasing A: every entry inserted before the current one has A >= the
    // current A, so the current A is the minimum of the joined path.
    sort(range(q), cmp);
    for (const auto &i : q) {
        const long long A = i.x.x;
        const long long Bi = bb[i.x.y];
        // Need A * (Bi + Bj + b[x]) >= V, i.e. Bj >= ceil(V / A) - Bi - b[x].
        // Evaluated without forming A * (Bi + b[x]), which can overflow long long
        // when a and the summed b are both large.
        const long long lim = max(0ll, ceil_div(V, A) - Bi - b[x]);
        const int W = lower_bound(bb + 1, bb + k + 1, lim) - bb;
        // Query before inserting, so an entry is never paired with itself.
        const int w = sum(W, Id[i.y]);
        if (w != INF) ans = min(ans, Bi + bb[w] + b[x]);
        add(i.x.y, i.x.y, Id[i.y]);
    }
    for (const auto &i : q) del(i.x.y);
    vis[x] = 1;
    for (int i = h[x]; ~i; i = ne[i]) work(e[i]);
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    memset(h, -1, sizeof(h));
    cin >> n >> V;
    For(i, 1, n) cin >> a[i] >> b[i];
    For(i, 1, n) if (a[i] * b[i] >= V) ans = min(ans, b[i]);
    For(i, 1, n) tr[i] = tr1[i] = mp(INF, 0);
    For(i, 1, n - 1) {
        int a, b;
        cin >> a >> b;
        add(a, b), add(b, a);
    }
    work(1);
    cout << ans;
    return 0;
} 

:::