Solution-P6943

· · 题解

无脑做法。

考虑维护每条边的流量对应的代价凸函数:设 $f_{u,i}$ 表示 $u$ 的父亲边流量为 $i$ 时对应的最小代价。流量的符号约定为:从 $u$ 流向 $fa_u$ 时流量为正,从 $fa_u$ 流向 $u$ 时流量为负。

我们维护 f 的差分数组 g,转移分为三部分:

综上,我们需要对差分数组进行如下操作:

使用平衡树 + 启发式合并维护,时间复杂度 $\mathcal{O}(V\log^2 V)$,其中 $V=\sum a_i$。

注意维护差分数组的同时,还需要维护函数最左端的取值 val_u

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
#define FR first
#define SE second
#define int long long
using namespace std;
// Reads one signed decimal integer from stdin.
// Every character before the first digit is skipped; if a '-' occurs among the
// skipped characters the result is negated.
// Returns the digits parsed so far (0 if none) once the input is exhausted,
// instead of looping forever on EOF.
inline int read() {
    int x = 0; bool op = 0;
    // Keep getchar()'s result un-narrowed: storing it in a char makes EOF
    // indistinguishable from data and can pass negative values to isdigit().
    auto c = getchar();
    while(c != EOF && !isdigit(c))op |= (c == '-'), c = getchar();
    // x * 10 + digit, written with shifts; (c ^ 48) maps '0'..'9' to 0..9.
    while(c != EOF && isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return op ? -x : x;
}
// Capacity of every per-vertex array, of the adjacency lists and of the treap node pool nd.
const int N = 1e6 + 10;
// n: vertex count read in main. tot: number of treap nodes handed out by New so far.
int n, tot;
// a[i], b[i]: the two integers read for vertex i, in that order.
// st[u]:  non-positive accumulator; dfs adds a[u] - b[u] when a[u] <= b[u] and the st of
//         every child. -st[u] is the split position used by update.
// val[u]: scalar carried alongside u's treap; update adds abs(st[u]) * w and dfs sums the
//         children's values (per the write-up: the function value at its leftmost point).
// rt[u]:  index of the root of the treap attached to vertex u, 0 when empty.
int a[N], b[N], st[N], val[N], rt[N];
// G[u]: list of (neighbour, edge weight) pairs; main stores every edge in both directions.
vector<pii> G[N];
// Treap node, addressed by its index in nd; index 0 stands for the empty tree.
//   ls, rs - indices of the left / right child
//   tg     - addend already applied to this node but not yet to its children (mark / pushdown)
//   val    - value stored at this node
//   ky     - priority drawn from rnd in New; merge keeps the smaller ky on top
//   sz     - number of nodes in the subtree, recomputed by pushup
struct Node {
    int ls, rs, tg, val, ky, sz;
}nd[N];
// Generator seeded from the current time; New draws each node's priority (ky) from it.
mt19937 rnd(time(0));
// Takes the next unused slot of nd, initialises it as a one-node treap holding x
// with a freshly drawn priority, and returns the slot's index.
int New(int x) {
    ++tot;
    auto &fresh = nd[tot];
    fresh.ls = 0;
    fresh.rs = 0;
    fresh.tg = 0;
    fresh.sz = 1;
    fresh.val = x;
    fresh.ky = rnd();
    return tot;
}
void mark(int x, int w) {
    if(x == 0)return ;
    nd[x].val += w; nd[x].tg += w;
    return ;
}
void pushdown(int x) {
    mark(nd[x].ls, nd[x].tg);
    mark(nd[x].rs, nd[x].tg);
    nd[x].tg = 0;
    return ;
}
// Recomputes the subtree size of node x from its two children; does nothing for index 0.
void pushup(int x) {
    if(x != 0) {
        nd[x].sz = 1 + nd[nd[x].ls].sz + nd[nd[x].rs].sz;
    }
}
// Concatenates treap x (all of its nodes first) with treap y and stores the
// resulting root in k. The node with the smaller ky becomes the root.
void merge(int &k, int x, int y) {
    if(x == 0 || y == 0) {
        // At most one side is non-empty, so the sum is that side's root (or 0).
        k = x + y;
        return ;
    }
    pushdown(x);
    pushdown(y);
    if(nd[x].ky < nd[y].ky) {
        // x stays on top; y is merged into x's right subtree.
        k = x;
        merge(nd[x].rs, nd[x].rs, y);
    }
    else {
        // y stays on top; x is merged into y's left subtree.
        k = y;
        merge(nd[y].ls, x, nd[y].ls);
    }
    pushup(k);
}
// Splits treap k by position: x receives the first w nodes in in-order,
// y receives the remaining ones.
void split1(int k, int &x, int &y, int w) {
    if(k == 0) {
        x = 0;
        y = 0;
        return ;
    }
    pushdown(k);
    // Number of nodes covered by k's left subtree together with k itself.
    const int upToRoot = nd[nd[k].ls].sz + 1;
    if(upToRoot <= w) {
        // k and its left subtree belong to the prefix; keep splitting on the right.
        x = k;
        split1(nd[k].rs, nd[k].rs, y, w - upToRoot);
    }
    else {
        // k and its right subtree belong to the suffix; keep splitting on the left.
        y = k;
        split1(nd[k].ls, x, nd[k].ls, w);
    }
    pushup(x);
    pushup(y);
}
// Splits treap k by value: descending from the root, a node whose val is <= w goes
// to x together with its left subtree, any other node goes to y with its right subtree.
void split2(int k, int &x, int &y, int w) {
    if(k == 0) {
        x = 0;
        y = 0;
        return ;
    }
    pushdown(k);
    const bool keepLeft = (nd[k].val <= w);
    if(keepLeft) {
        x = k;
        split2(nd[k].rs, nd[k].rs, y, w);
    }
    else {
        y = k;
        split2(nd[k].ls, x, nd[k].ls, w);
    }
    pushup(x);
    pushup(y);
}
// Applies an edge of weight w to vertex u's data: val[u] grows by |st[u]| * w,
// the first -st[u] treap entries are decreased by w and all later ones increased by w.
void update(int u, int w) {
    val[u] += abs(st[u]) * w;
    int head = 0, tail = 0;
    split1(rt[u], head, tail, -st[u]);
    mark(head, -w);
    mark(tail, w);
    merge(rt[u], head, tail);
}
void ins(int &k, int u) {
    if(u == 0)return ;
    pushdown(u);
    ins(k, nd[u].ls); ins(k, nd[u].rs);
    int x = 0, y = 0;
    nd[u].ls = nd[u].rs = nd[u].tg = 0;
    nd[u].sz = 1;
    split2(k, x, y, nd[u].val);
    merge(x, x, u); merge(k, x, y);
    return ;
}
// Debug helper: prints the values of the treap rooted at u in in-order,
// separated by spaces. Pending addends are not pushed down first.
void print(int u) {
    if(u == 0)return ;
    print(nd[u].ls);
    // val is a long long (int is macro-defined to long long in this file),
    // so it must be printed with %lld; the cast keeps the call well-formed either way.
    printf("%lld ", static_cast<long long>(nd[u].val));
    print(nd[u].rs);
    return ;
}
// Post-order traversal from u (parent fa) that assembles u's treap, st[u] and val[u].
// For each child v: recurse, apply the edge weight to v's data (update), then fold
// v's treap, st and val into u's. Afterwards u's own a[u] - b[u] is accounted for:
// a positive difference inserts that many zero entries, otherwise it is added to st[u].
// NOTE(review): recursion depth equals the tree height; very deep trees may need a larger stack.
void dfs(int u, int fa) {
    for(const auto &e : G[u]) {
        const int v = e.first, w = e.second;
        if(v == fa)continue;
        dfs(v, u);
        update(v, w);
        // Small-to-large: keep the bigger treap in rt[u], re-insert the smaller node by node.
        if(nd[rt[u]].sz < nd[rt[v]].sz)swap(rt[u], rt[v]);
        ins(rt[u], rt[v]);
        st[u] += st[v];
        val[u] += val[v];
    }
    const int surplus = a[u] - b[u];
    if(surplus > 0) {
        // All new zeros go to the same place (right after the entries <= 0), so one
        // split and one final merge suffice instead of a split/merge pair per zero.
        int x = 0, y = 0;
        split2(rt[u], x, y, 0);
        for(int i = 0; i < surplus; i++)merge(x, x, New(0));
        merge(rt[u], x, y);
    }
    else st[u] += surplus;
    return ;
}
// Sum collected by calc_ans; main prints ans + val[1].
int ans;
// Walks the treap rooted at u in in-order and, while st[1] is still negative,
// adds the visited value to ans and increments st[1]. In effect the first
// -st[1] entries (by in-order position) are summed.
void calc_ans(int u) {
    if(u != 0) {
        pushdown(u);
        calc_ans(nd[u].ls);
        if(st[1] < 0) {
            ans += nd[u].val;
            ++st[1];
        }
        calc_ans(nd[u].rs);
    }
}
// Entry point: reads n, then n - 1 weighted edges, then one (a, b) pair per vertex;
// runs dfs from vertex 1, sums the needed treap entries and prints the total.
signed main() {
    n = read();
    for(int edge = 1; edge < n; edge++) {
        int u = read();
        int v = read();
        int w = read();
        // Store the edge in both directions.
        G[u].emplace_back(v, w);
        G[v].emplace_back(u, w);
    }
    for(int i = 1; i <= n; i++) {
        a[i] = read();
        b[i] = read();
    }
    dfs(1, 0);
    calc_ans(rt[1]);
    printf("%lld\n", ans + val[1]);
    return 0;
}