P11237 题解
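首先考虑 $v_1 \le v_2$ 的情况：对 $P \to T$ 路径上的点 $x$，记 $a_1=\mathrm{dist}(P,x)$，$a_2=\mathrm{dist}(T,x)$。找到路径上离 $P$ 最近的、满足 $a_1/v_1 > a_2/v_2$（小偷比警察先到）的点 $x$。小偷跑到 $x$ 后，逃进 $x$ 处不朝向 $P$ 的最长分支（长度记为 $d_x$），答案为 $(a_1+d_x)/v_1$。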
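现在再考虑 $v_1 > v_2$ 的情况：小偷若在 $x$ 逃进分支，被抓的时间是 $\min\left(\frac{d_x+a_1}{v_1},\ \frac{a_1-a_2}{v_1-v_2}\right)$。沿路径找到第一个满足 $\frac{d_x+a_1}{v_1} < \frac{a_1-a_2}{v_1-v_2}$ 的点 $x$。答案取 $\frac{d_x+a_1}{v_1}$ 与前一个点 $b$ 处 $\frac{a_1-a_2}{v_1-v_2}$ 中的较大者；若不存在这样的点，答案为 $\frac{d}{v_1-v_2}$，其中 $d=\mathrm{dist}(P,T)$。为啥不用判断会不会在链上追到（$a_1<a_2$ 时会出负数）？因为这样一定不优。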
总复杂度 $O(n+q\log n)$：每次询问只有 $O(\log n)$ 条重链段，且只在其中一段内二分。我写了 5k，好像双 $\log$ 的做法更好写。
#include "police.h"
#include <bits/stdc++.h>
using namespace std;
// Type aliases instead of textual macros: they are scoped and type-checked,
// and they cannot silently rewrite other tokens the way `#define` can.
using ll = long long;
using db = double;
// Capacity for 1-indexed vertex arrays; assumes n <= 1e5 (TODO confirm against constraints).
const int N = 1e5 + 10;
int n, q;
vector<pair<int, ll>> e[N];  // adjacency list: (neighbour, edge length)
// mx/se: largest and second-largest path length leaving a vertex, reached through
// the distinct neighbours mxp/sep. dfs3 fills them downward, dfs4 adds the upward
// direction. disrt[x]: distance from root 1 to x.
ll mx[N], se[N], disrt[N];
int mxp[N], sep[N];
int sum[N], dep[N], Fa[N];  // subtree size, depth (root = 1), parent (root's is 0)
// Heavy child, chain head, DFS index (heavy child first), inverse of in[], index counter.
int son[N], top[N], in[N], p[N], timer;
// First DFS from x, whose parent is fa.
// Records parent, depth and subtree size, and selects the heavy child son[x]:
// the child with the strictly largest subtree, so on ties the earliest child wins.
void dfs1(int x, int fa) {
    Fa[x] = fa;
    dep[x] = dep[fa] + 1;
    sum[x] = 1;
    for (const auto& edge : e[x]) {
        const int child = edge.first;
        if (child == fa) continue;
        dfs1(child, x);
        sum[x] += sum[child];
        if (sum[child] > sum[son[x]]) son[x] = child;
    }
}
// Second DFS: heavy-light labelling.
// The heavy child is visited first, so each chain occupies a contiguous range of in[].
// p[] is the inverse of in[], and top[x] is the head of x's chain.
void dfs2(int x, int cap) {
    top[x] = cap;
    in[x] = ++timer;
    p[timer] = x;
    if (son[x] != 0) dfs2(son[x], cap);
    for (const auto& edge : e[x]) {
        const int y = edge.first;
        if (y == son[x] || y == Fa[x]) continue;
        dfs2(y, y);  // every light child heads a new chain
    }
}
// Offers candidate length d, reached via neighbour c, to x's top-two list
// (mx/mxp is the best, se/sep the runner-up).
// Comparisons are strict, so a value equal to an existing entry does not replace it.
void upd(int x, ll d, int c) {
    if (d > mx[x]) {
        // New best: the previous best becomes the runner-up.
        se[x] = mx[x];
        sep[x] = mxp[x];
        mx[x] = d;
        mxp[x] = c;
        return;
    }
    if (d > se[x]) {
        se[x] = d;
        sep[x] = c;
    }
}
void dfs3(int x) {
for(auto v:e[x]) if(v.first!=Fa[x]) {
disrt[v.first]=disrt[x]+v.second;
dfs3(v.first);
upd(x,mx[v.first]+v.second,v.first);
}
}
void dfs4(int x,ll d) {
upd(x,d,Fa[x]);
for(auto v:e[x]) if(v.first!=Fa[x]) {
dfs4(v.first,max(d,mxp[x]!=v.first? mx[x]:se[x])+v.second);
}
}
bool check(int x,int b,ll a1,ll a2,ll v1,ll v2) {
ll di=(mxp[x]==b? se[x]:mx[x]);
// cerr<<di<<'\n';
return (1.0*di+1.0*a1)/(1.0*v1)<(1.0*a1-1.0*a2)/(1.0*v1-1.0*v2);
}
// Answers one query: the police start at P with speed v1, the thief at T with speed v2.
// Returns a time as an unreduced fraction {numerator, denominator}.
//
// The P->T path is cut into heavy-light chain segments:
//   - `up` runs from P up to the LCA g. Each pair is (in[low], in[high]), first >= second.
//   - `dn` runs from g down to T. Each pair is (in[high], in[low]), first <= second;
//     after reverse() it is in path order.
// Segments are scanned in path order and only each segment's last vertex is tested.
// Inside the first segment whose last vertex passes, a binary search finds the first
// passing vertex x. This relies on the predicate being monotone along the path.
// b is the vertex just before x on the path. lst is the vertex just before the
// current segment's first vertex.
//
// All time comparisons are exact (128-bit cross-multiplication) rather than double
// divisions. Assumes v1 > 0 and v2 > 0 (TODO confirm against the constraints).
array<ll,2> query(int P, int T, ll v1, ll v2) {
    using i128 = __int128;
    // n1 / d1 > n2 / d2, for positive d1 and d2.
    auto later = [](ll n1, ll d1, ll n2, ll d2) {
        return (i128)n1 * d2 > (i128)n2 * d1;
    };
    // Longest branch leaving x, ignoring the neighbour b we arrived from.
    auto branch = [](int x, int b) { return mxp[x] == b ? se[x] : mx[x]; };

    vector<pair<int,int>> up, dn;
    int u = P, v = T, g;
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) {
            up.emplace_back(in[u], in[top[u]]);
            u = Fa[top[u]];
        } else {
            dn.emplace_back(in[top[v]], in[v]);
            v = Fa[top[v]];
        }
    }
    if (dep[u] < dep[v]) g = u, dn.emplace_back(in[u], in[v]);
    else g = v, up.emplace_back(in[u], in[v]);
    const ll d = disrt[P] + disrt[T] - disrt[g] * 2;  // dist(P, T)
    reverse(dn.begin(), dn.end());

    if (v1 <= v2) {
        // Find the first vertex x (from P) where the police arrive strictly after the
        // thief, i.e. a1 / v1 > a2 / v2. The answer is (a1 + branch) / v1.
        int lst = P;
        for (const auto& li : up) {
            int x = p[li.second];
            ll a1 = disrt[P] - disrt[x], a2 = d - a1;
            if (!later(a1, v1, a2, v2)) {
                lst = x;
                continue;
            }
            // Largest index (earliest path vertex) in this segment that passes.
            int l = li.first, r = li.second, res = li.second;
            while (l >= r) {
                int mid = (l + r + 1) / 2;
                ll m1 = disrt[P] - disrt[p[mid]];
                if (later(m1, v1, d - m1, v2)) res = mid, r = mid + 1;
                else l = mid - 1;
            }
            x = p[res];
            int b = (res == li.first ? lst : p[res + 1]);
            a1 = disrt[P] - disrt[x];
            return {a1 + branch(x, b), v1};
        }
        for (const auto& li : dn) {
            int x = p[li.second];
            ll a2 = disrt[T] - disrt[x], a1 = d - a2;
            if (!later(a1, v1, a2, v2)) {
                lst = x;
                continue;
            }
            // Smallest index (earliest path vertex) in this segment that passes.
            int l = li.first, r = li.second, res = li.second;
            while (l <= r) {
                int mid = (l + r) / 2;
                ll m2 = disrt[T] - disrt[p[mid]];
                if (later(d - m2, v1, m2, v2)) res = mid, r = mid - 1;
                else l = mid + 1;
            }
            x = p[res];
            int b = (res == li.first ? lst : p[res - 1]);
            a2 = disrt[T] - disrt[x];
            a1 = d - a2;
            return {a1 + branch(x, b), v1};
        }
        // At T itself a1 = d and a2 = 0, so this point is only reached when d == 0
        // (e.g. P == T). It falls through to the {-1, -1} sentinel below.
    } else {
        // Find the first vertex x where check() holds. The answer is the larger of
        // x's branch time (di + a1) / v1 and the chase time (al1 - al2) / (v1 - v2)
        // at the previous vertex b.
        int lst = P;
        // a1 / a2 of lst. At P itself a1 = 0 and a2 = d; the original started al2 at 0.
        ll al1 = 0, al2 = d;
        for (const auto& li : up) {
            int x = p[li.second];
            int b = (li.first == li.second ? lst : p[li.second + 1]);
            ll a1 = disrt[P] - disrt[x], a2 = d - a1;
            if (!check(x, b, a1, a2, v1, v2)) {
                lst = x, al1 = a1, al2 = a2;
                continue;
            }
            int l = li.first, r = li.second, res = li.second;
            while (l >= r) {
                int mid = (l + r + 1) / 2;
                int xm = p[mid], bm = (mid == li.first ? lst : p[mid + 1]);
                ll m1 = disrt[P] - disrt[xm];
                if (check(xm, bm, m1, d - m1, v1, v2)) res = mid, r = mid + 1;
                else l = mid - 1;
            }
            x = p[res];
            b = (res == li.first ? lst : p[res + 1]);
            if (b != lst) {
                al1 = disrt[P] - disrt[b];
                al2 = d - al1;
            }
            a1 = disrt[P] - disrt[x];
            ll di = branch(x, b);
            if (later(di + a1, v1, al1 - al2, v1 - v2)) return {di + a1, v1};
            return {al1 - al2, v1 - v2};
        }
        for (const auto& li : dn) {
            int x = p[li.second];
            int b = (li.first == li.second ? lst : p[li.second - 1]);
            ll a2 = disrt[T] - disrt[x], a1 = d - a2;
            if (!check(x, b, a1, a2, v1, v2)) {
                lst = x, al1 = a1, al2 = a2;
                continue;
            }
            int l = li.first, r = li.second, res = li.second;
            while (l <= r) {
                int mid = (l + r) / 2;
                int xm = p[mid], bm = (mid == li.first ? lst : p[mid - 1]);
                ll m2 = disrt[T] - disrt[xm];
                if (check(xm, bm, d - m2, m2, v1, v2)) res = mid, r = mid - 1;
                else l = mid + 1;
            }
            x = p[res];
            b = (res == li.first ? lst : p[res - 1]);
            if (b != lst) {
                al2 = disrt[T] - disrt[b];
                al1 = d - al2;
            }
            a2 = disrt[T] - disrt[x];
            a1 = d - a2;
            ll di = branch(x, b);
            if (later(di + a1, v1, al1 - al2, v1 - v2)) return {di + a1, v1};
            return {al1 - al2, v1 - v2};
        }
        // check() never held, so the answer is the chase time from T: d / (v1 - v2).
        return {d, v1 - v2};
    }
    return {-1, -1};
}
std::vector<std::array<long long, 2>> police_thief(std::vector<int> A, std::vector<int> B, std::vector<int> D,
std::vector<int> P, std::vector<int> V1, std::vector<int> T, std::vector<int> V2){
n=A.size()+1,q=(int)P.size();
for(int i=0;i<n-1;i++) {
e[A[i]+1].emplace_back(B[i]+1,D[i]);
e[B[i]+1].emplace_back(A[i]+1,D[i]);
}
dfs1(1,0);
dfs2(1,1);
dfs3(1);
dfs4(1,0);
// for(int i=1;i<=n;i++) cerr<<in[i]<<' ';
// cerr<<'\n';
std::vector<std::array<long long, 2>> C(q);
for(int te=0;te<q;te++) {
// cerr<<"TE\n";
C[te]=query(P[te]+1,T[te]+1,V1[te],V2[te]);
}
// cerr<<mx[8]<<"\n";
return C;
}