题解 P9476【地铁】
假如没有地铁,那么每条边对答案的贡献为
现在有了地铁,使得答案减少了
假设一个人从
其中
上式也可以改写为:
也就是说,在这个人移动的过程中,经过的地铁的每条边对
考虑每条边和每个点产生贡献的次数。
每条边产生贡献,当且仅当它被经过。因此次数等于
每个点产生贡献,当且仅当它作为“内点”被经过。由于地铁是一条链,在链上,与之相邻的两条边均被经过。因此次数等与
因此对
可以使用树形 DP 求出
如果树根为
转移有两种可能性:
一、另一端就是
即
其中
二、另一端不是
即
如图。
假设
如果
如果
如图。
如果暴力枚举
如果有
然后将
由
由于
时间复杂度为
代码:
#include <bits/stdc++.h>
using namespace std;
const int _ = 1e5 + 10;
const int __ = 2e5 + 10;
int id, n, t, e, hd[_], nx[__], to[__], ln1[__], ln2[__];
long long siz[_], N;
__int128 dp[_];
inline void add(int u, int v, int w1, int w2) {
e++;
nx[e] = hd[u];
to[e] = v;
ln1[e] = w1;
ln2[e] = w2;
hd[u] = e;
}
__int128 sum, dif;
void dfs1(int x, int f) {
for (int i = hd[x]; i; i = nx[i]) {
int y = to[i];
if (y != f) {
dfs1(y, x);
siz[x] += siz[y];
sum += __int128(siz[y]) * (N - siz[y]) * (ln1[i]);
}
}
}
int m;
struct node {
__int128 x;
__int128 y;
} arr[_];
int l, r, q[_];
inline bool cmp(node a, node b) {
if (a.x == b.x) return (a.y > b.y);
return (a.x < b.x);
}
inline bool eqn(node a, node b) {
return (a.x == b.x);
}
inline __float128 slope(node a, node b) {
return ((__float128)(b.y - a.y) / (__float128)(b.x - a.x));
}
void dfs2(int x, int f, __int128 z) {
dp[x] = z;
for (int i = hd[x]; i; i = nx[i]) {
int y = to[i];
if (y != f) {
dfs2(y, x, __int128(siz[y]) * (N - siz[y]) * (ln1[i] - ln2[i] - t));
if (f) dp[x] = max(dp[x], dp[y] + z + __int128(siz[y]) * (N - siz[x]) * (t));
dif = max(dif, dp[y]);
}
}
m = 0;
for (int i = hd[x]; i; i = nx[i]) {
int y = to[i];
if (y != f) {
m++;
arr[m].x = siz[y];
arr[m].y = dp[y];
}
}
sort(arr+1, arr+m+1, cmp);
for (int i = 1; i < m; i++) {
if (arr[i].x == arr[i+1].x) {
dif = max(dif, arr[i].y + arr[i+1].y + t * arr[i].x * arr[i+1].x);
}
}
m = unique(arr+1, arr+m+1, eqn) - arr - 1;
l = r = 1;
q[1] = 1;
for (int i = 2; i <= m; i++) {
while (r > l && slope(arr[q[l]], arr[q[l+1]]) > (-t * arr[i].x)) l++;
dif = max(dif, arr[i].y + arr[q[l]].y + t * arr[i].x * arr[q[l]].x);
while (r > l && slope(arr[q[r]], arr[i]) > slope(arr[q[r-1]], arr[i])) r--;
q[++r] = i;
}
}
int main() {
cin >> id >> n >> t;
for (int i = 1; i <= n; i++) {
cin >> siz[i];
N += siz[i];
}
for (int i = 1; i < n; i++) {
int u, v, w1, w2;
cin >> u >> v >> w1 >> w2;
add(u, v, w1, w2);
add(v, u, w1, w2);
}
dfs1(1, 0);
dfs2(1, 0, __int128(0));
__int128 ans = sum - dif;
string str;
while (ans) {
str = (char)((ans % 10) + 48) + str;
ans /= 10;
}
cout << str << endl;
return 0;
}