Solution-P6943
无脑做法。
考虑维护每条边的流量对应的代价凸函数,设
我们维护
- 合并子树。
- 是一个
(\min,+) 卷积,直接闵可夫斯基和归并差分数组即可。
- 是一个
- 处理当前点贡献的流量。
- 设
t_i=a_i-b_i ,如果t_i>0 ,可以选择流的流量x\in [0,t_i] ,那么操作相当于将f_{u,*} 和f(x)=0,x\in[0,t_i] ,做(\min,+) 卷积,仍然使用闵可夫斯基和归并。如果t_i<0 ,那么操作相当于将f_{u,*} 平移t_i 个单位,那么记录st_u 表示u 的平移量即可。
- 设
- 处理父亲边的代价。
- 操作相当于加上一个函数
f(x)=w|x| ,w 为边权。那么操作即为对差分数组下标\leq 0 的部分加上-w ,下标>0 的部分加上w 。
- 操作相当于加上一个函数
综上,我们需要对差分数组进行如下操作:
- 归并两个差分数组。
- 区间加。
使用平衡树+启发式合并维护,时间复杂度
注意维护差分数组的同时,还需要维护函数最左端的取值
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
#define FR first
#define SE second
#define int long long
using namespace std;
inline int read() {
int x = 0; bool op = 0;
char c = getchar();
while(!isdigit(c))op |= (c == '-'), c = getchar();
while(isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return op ? -x : x;
}
const int N = 1e6 + 10;
int n, tot;
int a[N], b[N], st[N], val[N], rt[N];
vector<pii> G[N];
struct Node {
int ls, rs, tg, val, ky, sz;
}nd[N];
mt19937 rnd(time(0));
int New(int x) {
int cur = ++tot;
nd[cur].ls = nd[cur].rs = nd[cur].tg = 0;
nd[cur].sz = 1; nd[cur].val = x; nd[cur].ky = rnd();
return cur;
}
void mark(int x, int w) {
if(x == 0)return ;
nd[x].val += w; nd[x].tg += w;
return ;
}
void pushdown(int x) {
mark(nd[x].ls, nd[x].tg);
mark(nd[x].rs, nd[x].tg);
nd[x].tg = 0;
return ;
}
void pushup(int x) {
if(x == 0)return ;
nd[x].sz = nd[nd[x].ls].sz + nd[nd[x].rs].sz + 1;
return ;
}
void merge(int &k, int x, int y) {
if(x == 0 || y == 0)return k = x + y, void();
pushdown(x); pushdown(y);
if(nd[x].ky < nd[y].ky) {k = x; merge(nd[k].rs, nd[x].rs, y);}
else {k = y; merge(nd[k].ls, x, nd[y].ls);}
pushup(k);
return ;
}
void split1(int k, int &x, int &y, int w) {
if(k == 0)return x = y = 0, void();
pushdown(k);
if(nd[nd[k].ls].sz + 1 <= w) {x = k; split1(nd[k].rs, nd[x].rs, y, w - nd[nd[k].ls].sz - 1);}
else {y = k; split1(nd[k].ls, x, nd[y].ls, w);}
pushup(x); pushup(y);
return ;
}
void split2(int k, int &x, int &y, int w) {
if(k == 0)return x = y = 0, void();
pushdown(k);
if(nd[k].val <= w) {x = k; split2(nd[k].rs, nd[x].rs, y, w);}
else {y = k; split2(nd[k].ls, x, nd[y].ls, w);}
pushup(x); pushup(y);
return ;
}
void update(int u, int w) {
int x = 0, y = 0;
val[u] += abs(st[u]) * w;
split1(rt[u], x, y, -st[u]);
mark(x, -w); mark(y, w);
merge(rt[u], x, y);
return ;
}
void ins(int &k, int u) {
if(u == 0)return ;
pushdown(u);
ins(k, nd[u].ls); ins(k, nd[u].rs);
int x = 0, y = 0;
nd[u].ls = nd[u].rs = nd[u].tg = 0;
nd[u].sz = 1;
split2(k, x, y, nd[u].val);
merge(x, x, u); merge(k, x, y);
return ;
}
void print(int u) {
if(u == 0)return ;
print(nd[u].ls);
printf("%d ", nd[u].val);
print(nd[u].rs);
return ;
}
void dfs(int u, int fa) {
for(auto e : G[u]) {
int v = e.FR, w = e.SE;
if(v == fa)continue;
dfs(v, u); update(v, w);
if(nd[rt[u]].sz < nd[rt[v]].sz)swap(rt[u], rt[v]);
ins(rt[u], rt[v]); st[u] += st[v]; val[u] += val[v];
}
if(a[u] > b[u]) {
for(int i = 1; i <= a[u] - b[u]; i++) {
int x = 0, y = 0;
split2(rt[u], x, y, 0);
merge(x, x, New(0)); merge(rt[u], x, y);
}
}
else st[u] += a[u] - b[u];
// printf("rt:%d %d %d\n", u, rt[u], st[u]);
return ;
}
int ans;
void calc_ans(int u) {
if(u == 0)return ;
pushdown(u);
calc_ans(nd[u].ls);
// printf("%d ", nd[u].val);
if(st[1] < 0)ans += nd[u].val, st[1]++;
calc_ans(nd[u].rs);
return ;
}
signed main() {
n = read();
for(int i = 1; i < n; i++) {
int u = read(), v = read(), w = read();
G[u].pb(mp(v, w)); G[v].pb(mp(u, w));
}
for(int i = 1; i <= n; i++)a[i] = read(), b[i] = read();
dfs(1, 0); calc_ans(rt[1]);
printf("%lld\n", ans + val[1]);
return 0;
}