P8063 [BalkanOI 2012] Shortest paths 题解

· · 题解

模板题。

显然该问题可以规约到删边最短路,而删边最短路显然可以做到 O((n+m)\log m) 的复杂度,因此此题显然可以做到至少 O((n+m)\log m)。下面简单介绍一下删边最短路做法。

首先对起点和终点分别跑最短路,并记录最短路的前驱,即最短路树。我们标记 S\rightarrow T 的最短路为关键的,那么如果删去的边是非关键的,答案显然为 S\rightarrow T 最短路,直接输出即可。

否则新的最短路一定会与原最短路重叠且仅重叠恰好一段前缀和一段后缀,否则调整后一定不劣。那我们可以枚举每一条非关键边,先考虑新最短路怎么经过 u\rightarrow v,一定是先走 S\rightarrow u 最短路,再经过这条边,最后再走 v\rightarrow T 的最短路,计算出这条路径的权值之后,只需标记在其与原最短路不交的那一段连续区间即可。反过来是一样的。寻找路径交可以两遍广搜,标记可以直接线段树。

所有相关正确性证明都可以从最短路性质来调整考虑,复杂度瓶颈为跑两次单源最短路与标记 O(m) 个区间,为 O((n+m)\log n)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <unordered_map>
#define int long long
using namespace std;

int n,m,cnt,S,T;
int a[100005];
int h[100005];
int uu[200005];
int vv[200005];
int ww[200005];
int d1[100005];
int d2[100005];
int p1[100005];
int p2[100005];
int pp1[100005];
int pp2[100005];
int vis[200005];
vector <pair <int,int>> g[100005];
unordered_map <int,int> mp[100005];

inline void in(int &n){
    n=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
    return ;
}

inline void dij(int S,int d[],int p[]){
    priority_queue <pair <int,int>> q;
    for(int i=1;i<=n;i++) d[i]=1e9,vis[i]=0;
    d[S]=0;
    q.push({0,S});
    while(!q.empty()){
        int u=q.top().second;
        q.pop();
        if(vis[u]) continue;
        vis[u]=1;
        for(auto tmp:g[u]){
            int v=tmp.first,w=ww[tmp.second];
            if(d[v]>d[u]+w){
                d[v]=d[u]+w;
                p[v]=tmp.second;
                q.push({-d[v],v});
            }
        }
    }
    return ;
}

inline void find(int S,int p[],int pp[]){
    queue <int> q;
    q.push(S);
    while(!q.empty()){
        int u=q.front();
        q.pop();
        if(vis[u]) pp[u]=u;
        for(auto tmp:g[u]){
            int v=tmp.first;
            if(p[v]!=tmp.second) continue;
            pp[v]=pp[u];
            q.push(v);
        }
    }
    return ;
}

int mn[400005];
int mnn[100005];

inline void upd(int u,int l,int r,int L,int R,int x){
    if(L<=l&&r<=R){mn[u]=min(mn[u],x);return ;}
    int mid=(l+r)>>1;
    mn[u<<1]=min(mn[u<<1],mn[u]);
    mn[u<<1|1]=min(mn[u<<1|1],mn[u]);
    if(L<=mid) upd(u<<1,l,mid,L,R,x);
    if(R>mid) upd(u<<1|1,mid+1,r,L,R,x);
    return ;
}

inline void final(int u,int l,int r){
    if(l==r){mnn[l]=mn[u];return ;}
    int mid=(l+r)>>1;
    mn[u<<1]=min(mn[u<<1],mn[u]);
    mn[u<<1|1]=min(mn[u<<1|1],mn[u]);
    final(u<<1,l,mid);
    final(u<<1|1,mid+1,r);
    return ;
}

inline int get(int id){
    if(!vis[id]) return d1[T];
    return mnn[vis[id]];
}

signed main(){
    in(n),in(m),in(S),in(T);
    for(int i=1;i<=m;i++){
        in(uu[i]),in(vv[i]),in(ww[i]);
        mp[uu[i]][vv[i]]=i;
        mp[vv[i]][uu[i]]=i;
        g[uu[i]].push_back({vv[i],i});
        g[vv[i]].push_back({uu[i],i});
    }
    dij(S,d1,p1),dij(T,d2,p2);
    memset(vis,0,sizeof(vis));
    int U=T;
    while(U) vis[U]=1,a[++cnt]=U,U=uu[p1[U]]==U?vv[p1[U]]:uu[p1[U]];
    find(S,p1,pp1),find(T,p2,pp2);
    reverse(a+1,a+1+cnt);
    memset(vis,0,sizeof(vis));
    for(int i=1;i<=cnt;i++) h[a[i]]=i,vis[p1[a[i]]]=i;
    memset(mn,0x3f,sizeof(mn));
    for(int i=1;i<=m;i++){
        if(vis[i]) continue;
        int u=uu[i],v=vv[i],w=ww[i];
        if(pp1[u]&&pp2[v]){
            int l=h[pp1[u]]+1,r=h[pp2[v]],W=d1[u]+d2[v]+w;
            upd(1,1,cnt,l,r,W);
        }
        swap(u,v);
        if(pp1[u]&&pp2[v]){
            int l=h[pp1[u]]+1,r=h[pp2[v]],W=d1[u]+d2[v]+w;
            upd(1,1,cnt,l,r,W);
        }
    }
    final(1,1,cnt);
    int q,las;
    in(q),in(las);
    for(int i=1;i<q;i++){
        int x,y;
        in(x);
        y=mp[las][x];
        printf("%lld\n",get(y)>1e9?-1:get(y));
        las=x;
    }

    return 0;
}