题解 P6880【[JOI 2020 Final] オリンピックバス】

· · 题解

这里给出一篇更为通俗易懂的题解,算是对其它题解的补充说明吧。

思路

如果暴力枚举翻转每条边,每个进行一遍 dijkstra,复杂度 O(n^2m),需要优化。

本题的优化方法是减少 dij 的使用次数。因为许多边的翻转其实不影响最短路的长度,此时就不需要再调用 dij 求一边长度了,可以节省时间。

那么边可以分为两类,一种边翻转后对最短路长度产生影响,称为“关键边”,否则不产生影响就不是关键边。

对于每个关键边,需要修改一下边并重新调用 dij 计算最短路,因为 dijn^2 的,所以关键边的个数需要限制到 O(n) 级别;对于非关键边,因为有 m 条,需要在 O(n) 内求出路径长度,当然这道题比较神奇,可以 O(1) 就求出。

考虑如何处理非关键边,这里直接上结论:选择翻转边 (u,v) 的最终代价为 D+\min(dis(1,n),dis(1,v)+C+dis(u,n))+\min(dis(n,1),dis(n,v)+C+dis(u,1))。(dis(u,v) 就是 u \to v 的最短距离)

下面说明为什么这个结论是对的。(这里也是其它题解可能讲的不是很清楚的地方,当然这里确实不好说,因为关键边和非关键边是相互依存相互促进的,接下来这一大段话也会顺便讲掉如何处理关键边。)

首先需要给出关键边的定义(上面的那个不是定义,只是一个大概的理解):在 1\to i,i \to n,n \to i,i \to 1 (1 \le i \le n) 的其中一个的最短路径上的边。

容易发现,从一个定点出发(或一个定点结束)到其它每个点的最短路径构成一棵树(如果有两条长度相同的就随便选一条),这样关键路径的条数就是上面这 4 棵最短路径树上的边,是 O(n) 级别,满足要求。

有了定义,再回头看那个式子,这里就只看 1 \to n 了,因为 n \to 1 也同理。翻转边 (u,v)1 \to n 代价是 \min(dis(1,n),dis(1,v)+C+dis(u,n))

证明这个是对的,其实就需要两步:为什么能取到,为什么是最小的。式子中的 dis(1,n) 就是选择不经过 (u,v) 的翻转边,(这里可能有人会问,为什么不经过这条边还要翻转,原因是可能去的时候不经过,回来的时候就需要经过了。)因为 (u,v) 不是关键边,所以 dis(1,n) 不会改变,可以直接用;式子中的 dis(1,v)+C+dis(u,n) 就是选择经过 (u,v) 的翻转边,因为不是关键边,翻转 (u,v) 不会改变 1 到任何点、任何点到 n 的最短路径长度,所以 dis(1,v),dis(u,n) 不会改变,然后因为要经过这条边,所以一定是先最短路径走到 v,再用这条边到 u,然后走最短路径到 n。以上两种情况(走与不走翻转边)包含了所有情况,而它们各自是自己情况的最小,所以它们的较小值就是最终的最小值。

有了这样的结论,只需要先预处理出 dis(1,i),dis(i,n),dis(n,i),dis(i,1),便可以 O(1) 计算翻转非关键边的答案。

至此,我们分当前翻转的边是不是关键边两种情况进行了处理,这道题也就大功告成了,总复杂度 O(n^3+m)

写法的技巧和细节

这道题的代码并不好写,不然我也不至于调 2h,于是在这里总结了一些写法上的小技巧和容易出错的细节。

dij 不要用优先队列优化,那样是 m \log n,不如暴力的 n^2。同时别用邻接表,用邻接矩阵,方便翻转边。

关键边的求法:把每个点最终dis 是从哪条边转移过来的记录一下,这些边就是关键边。

如何求 dis(i,n),dis(i,1):建立反图,就可以以 n,1 为起点求了。不过因为这道题是邻接矩阵,方便操作,也可以在 dij 中做手脚,详细可以看下面的代码。

如何处理重边:邻接矩阵是显示不出来重边的,所以在求邻接矩阵时,需要保留边权最小的边,不过这还没完,因为边权最小的边可能被翻转掉,所以还需要记录边权第二小的边,这样可以在边权最小的边翻转后暂时顶替它的位置。但处理重边需要注意,边权大的边只能在邻接矩阵中被忽略,但也需要被存起来,因为这条边的 D 可能很小,把这条边翻转得到的答案可能更优,也就是说在翻转边的时候这条边也需要被考虑进来。

以上细节可能不全,那下面的代码和注释应该能解决你的所有疑惑。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int MAXN=210,MAXM=5e4+10;
int n,m,from[MAXN]; //from[i] 表示最终的 dis[i] 是从哪条边转移来的
ll c[MAXN][MAXN],c2[MAXN][MAXN]; //邻接矩阵的最小值和次小值
ll dist[4][MAXN],tdis[MAXN]; //dist[0,1,2,3] 分别是 1->i,i->n,n->i,i->1 的最短距离
bool vis[MAXN],key[MAXN][MAXN],used[MAXN][MAXN]; //key 记录关键边
struct Edge{
    int u,v;ll C,D;
}e[MAXM];
int read(){
    int num=0,sign=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')sign=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')num=num*10+ch-'0',ch=getchar();
    return num*sign;
}
void dij(int st,ll *dis,int op){ //op=0 是正图,op=1 是反图,op=2 表示翻转一条边的正图
    memset(vis,0,sizeof(vis));
    for(int i=0;i<=n;i++) dis[i]=2e18;
    dis[st]=0;
    while(1){
        int u=0;
        for(int i=1;i<=n;i++)
            if(!vis[i]&&dis[u]>dis[i]) u=i;
        if(!u) break;
        vis[u]=1;
        for(int i=1;i<=n;i++){
            ll w=(op==1)?c[i][u]:c[u][i]; //处理正图和反图
            if(!vis[i]&&dis[u]+w<dis[i]){
                dis[i]=dis[u]+w;
                if(op<2) from[i]=u; //记录当前转移来的边,但这未必是最终的
            }
        }
    }
    if(op<2) for(int i=1;i<=n;i++) key[from[i]][i]=1; //记录关键边
}
int main(){
    n=read(),m=read();
    memset(c,0x3f,sizeof(c)),memset(c2,0x3f,sizeof(c2));
    for(int i=1;i<=n;i++) c[i][i]=0;
    for(int i=1;i<=m;i++){
        int u=read(),v=read(),C=read(),D=read();
        e[i].u=u,e[i].v=v,e[i].C=C,e[i].D=D;
        if(c[u][v]>C) c2[u][v]=c[u][v],c[u][v]=C; //更新邻接矩阵最大值和次大值
        else if(c2[u][v]>C) c2[u][v]=C; //更新次大值
    }
    dij(n,dist[1],1),dij(1,dist[3],1); //预处理 4 个最短路
    dij(1,dist[0],0),dij(n,dist[2],0);
    ll ans=dist[0][n]+dist[2][1]; //不翻转边
    for(int i=1;i<=m;i++){
        int u=e[i].u,v=e[i].v;ll C=e[i].C,D=e[i].D;
        if(key[u][v]&&C==c[u][v]&&!used[u][v]){ //是关键边
            ll tmpu=c[u][v],tmpv=c[v][u];
            c[v][u]=min(c[v][u],C); //翻转边 (u,v)
            c[u][v]=(c[u][v]==C)?c2[u][v]:c[u][v]; //如果最小边被翻转,次小边顶替
            ll sum=D;
            dij(1,tdis,2);sum+=tdis[n]; //重算距离
            dij(n,tdis,2);sum+=tdis[1];
            ans=min(ans,sum); //更新答案
            c[u][v]=tmpu,c[v][u]=tmpv; //回溯
            used[u][v]=1;
        }
        else ans=min(ans,(ll)D+
            min(dist[0][n],dist[0][v]+C+dist[1][u])+
            min(dist[2][1],dist[2][v]+C+dist[3][u])); //套公式
    }
    printf("%lld\n",(ans>=1e18)?-1:ans);
    return 0; //华丽结束
}

不要忘记点赞~