P5306 [COCI2019] Transport 题解
这是一篇点分治 + 双指针的题解。
计算路径
设点分治中,当前找到的根为
分两种情况考虑合法性:
-
从
i 走到rt :显然需要对于任意在i 到rt 路径上的节点j ,满足w_i-w_j \ge d_i - d_j 。稍加变换得:w_i-d_i-(w_j-d_j)\ge 0 。此时只需要维护、记录最大的w_j-d_j 即可判断从i 到rt 的合法性。 -
从
rt 走到i :这种情况,路径的起点不一定是rt ,可以从其他子树中的一个节点k\rightarrow rt \rightarrow i 。设k\rightarrow rt 后还剩下的油量为x (直接从rt 出发那么x 为0 ),则rt\rightarrow i 路径上任意一点j ,需要满足x+w_{fa_j} \ge d_j 。依旧是稍加变换:x+(w_{fa_j}-d_j) \ge 0 ,维护、记录路径上最小的w_{fa_j}-d_j 即可。
统计答案
遍历所有子树,把所有路径都记录下来。
从
从
分别将两种情况按记录的权值排序,用
此时对于所有
注意要将
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 100005;
int rd(){
int x = 0;char ch = getchar();
while(ch < '0' || ch > '9')ch = getchar();
while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}
return x;
}
ll s1[N],s2[N],q1[N],q2[N],w[N],d[N],ans;
int tot,head[N],ver[2*N],edge[2*N],Next[2*N];
int n,a[N],rt,sum,s[N],Mas[N],vis[N],t1,t2,h1,h2;
void add(int x,int y,int z){
tot++;
ver[tot] = y;
edge[tot] = z;
Next[tot] = head[x];
head[x] = tot;
}
void Grt(int fa,int x){
s[x] = 1,Mas[x] = 0;
for(int i = head[x];i;i = Next[i]){
int y = ver[i];
if(y == fa || vis[y])continue;
Grt(x,y);
s[x] += s[y];
Mas[x] = max(Mas[x],s[y]);
}
Mas[x] = max(Mas[x],sum - s[x]);
if(!rt || Mas[x] < Mas[rt])rt = x;
}
void dfs(int fa,int x,ll Max,ll Min){//Max 最大 w[j] - d[j];Min 最小 w[fa[j]]-d[j]
if(w[x] - d[x] >= Max)//合法才记录
s1[++t1] = w[x] - d[x];//i到rt
s2[++t2] = Min;//rt到i,目前不能确定是否合法,都记录
for(int i = head[x];i;i = Next[i]){
int y = ver[i],z = edge[i];
if(y == fa || vis[y])continue;
d[y] = d[x] + z,w[y] = w[x] + a[y];
dfs(x,y,max(Max,w[x] - d[x]),min(Min,w[x] - d[y]));//注意更新 Max 和 Min
}
}
void calc(int x){
int L;
h1 = 0,h2 = 0;
d[x] = 0,w[x] = a[x];
for(int i = head[x];i;i = Next[i]){
int y = ver[i],z = edge[i];
if(vis[y])continue;
t1 = 0,t2 = 0;
d[y] = z,w[y] = w[x] + a[y];//初始化d[y]和w[y]
dfs(x,y,w[x] - d[x],w[x] - d[y]);
sort(s1 + 1,s1 + 1 + t1);//s1、s2记录当前子树内的路径
sort(s2 + 1,s2 + 1 + t2);
L = 1;
for(int i = t2;i >= 1;i--){//注意倒序枚举
while(L <= t1 && s1[L] + s2[i] - a[x] < 0)L++;
ans -= t1 - L + 1;//减去同一子树内的答案
}
for(int i = 1;i <= t1;i++)
q1[++h1] = s1[i];
for(int i = 1;i <= t2;i++)
q2[++h2] = s2[i];//q1和q2记录所有路径
}
sort(q1 + 1,q1 + 1 + h1);
sort(q2 + 1,q2 + 1 + h2);
L = 1;
for(int i = h2;i >= 1;i--){
if(q2[i] >= 0)//rt到i直接从rt出发
ans++;
while(L <= h1 && q1[L] + q2[i] - a[x] < 0)L++;//i到rt,rt到i都加了a[x],减掉一次
ans += h1 - L + 1;
}
ans += h1;//i出发到rt结束的情况
}
void solve(int x){
vis[x] = 1,calc(x);
for(int i = head[x];i;i = Next[i]){
int y = ver[i],z = edge[i];
if(vis[y])continue;
rt = 0,sum = s[y];
Grt(x,y),solve(rt);
}
}
int main(){
n = rd();
for(int i = 1;i <= n;i++)
a[i] = rd();
for(int i = 1;i < n;i++){
int x = rd(),y = rd(),z = rd();
add(x,y,z);
add(y,x,z);
}
rt = 0,sum = n;
Grt(0,1),solve(rt);
printf("%lld\n",ans);
return 0;
}