[JOI 2020 Final] オリンピックバス 题解

· · 题解

[JOI 2020 Final] オリンピックバス

题意

题面写得很清楚明白了。

题解

看到这道题的第一眼,会很显然地想到每次枚举一条翻转的边跑最短路,但这样的复杂度显然非常的大,只能拿到最低档的 5 分(捆绑数据要是打炸了还一分没有)。 分析一下数据范围,我们发现如果有一种 O(n^{3}) 左右的算法就可以不多不少刚刚好跑过。 首先我们明确包括一共有四种从1\sim n再从n\sim 1的走法:

(有点绕口) 再想想我们需要在没有任何边翻转的原图中先得到那些东西:

  1. 1 号城市到 n 号个城市的最短路。
  2. 每个城市到 1 号城市的最短路。
  3. 每个城市到 n 号城市的最短路。

前两个直接在原图跑最短路,后两个建个反图跑最短路即可。

至于为什么要建反图,是因为如果你把某条边翻转了之后,原先有在最短中的边就不存在了,所以不能直接从前两个的结果中赋值过来(还是不懂的可以画图手动模拟一下,很显然是不相同的)。

(那你说了那么多,还不是要枚举每条边去翻转???)

对于翻转操作,我们有一个奇技淫巧好的方法进行处理:因为如果我们没有翻转在 1\sim nn\sim 1 最短路上的边,上述第一中走法中是不会产生贡献的,而在最短路上的边翻转之后,对于我们当前求出的其他走法中也不会更优。

所以我们只需要在跑前两遍最短路时处理出在 1\sim nn\sim 1 最短路上的边,然后枚举这些边,暴力加边跑最短路,做第一种走法,其它的边就直接 O(1) 走后面三种走法,然后判断取最小值即可。

然后我们就可以限制枚举边翻转的复杂度在严格 O(n) 以内了。

实现有亿一点点细节,但是明确思路之后就不难写。

CODE

#include<cstdio>
#include<string>
#define R register int
#define N 205
#define M 50005
#define ll long long
#define inf 1234567891011
using namespace std;
struct G{ll to,next,id;ll w;}e[M];
int n,m,u[M],v[M],cnt,head[N],nxt[N],last[N];
ll c[M],d[M],ans1,ans2,ans,dist1[N],dist2[N][5];
bool bz[N],flag[M];
ll min(ll a,ll b) {return a<b?a:b;}
void read(int &x)
{
    x=0;ll f=1;char ch=getchar();
    while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();x*=f;
}
void read1(ll &x)
{
    x=0;ll f=1;char ch=getchar();
    while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();x*=f;
}
void add(int u,int v,ll w,int id)
{
    e[++cnt].to=v;e[cnt].w=w;e[cnt].id=id;
    e[cnt].next=head[u];head[u]=cnt;
}
void dij1(int s)
{
    for (R i=1;i<=n;++i) dist1[i]=inf,bz[i]=0,nxt[i]=0;
    dist1[s]=0;
    for (R i=1;i<=n;++i)
    {
        ll mn=inf;int k=n+1;
        for (R j=1;j<=n;++j) if (!bz[j] && dist1[j]<mn) mn=dist1[j],k=j;
        if (k==n+1) break;bz[k]=1;
        for (R j=head[k];j;j=e[j].next)
        {
            int v=e[j].to;
            if (dist1[v]>dist1[k]+e[j].w) dist1[v]=dist1[k]+e[j].w,nxt[v]=k,last[v]=j;
        }
    }
}
void dij2(int s,int num)
{
    for (R i=1;i<=n;++i) dist2[i][num]=inf,bz[i]=0;
    dist2[s][num]=0;
    for (R i=1;i<=n;++i)
    {
        ll mn=inf;int k=n+1;
        for (R j=1;j<=n;++j) if (!bz[j] && dist2[j][num]<mn) mn=dist2[j][num],k=j;
        if (k==n+1) break;bz[k]=1;
        for (R j=head[k];j;j=e[j].next)
        {
            int v=e[j].to;
            if (dist2[v][num]>dist2[k][num]+e[j].w) dist2[v][num]=dist2[k][num]+e[j].w;
        }
    }   
}
void build(ll u)
{
    while (u) flag[last[u]]=1,u=nxt[u];
}
int main()
{
    read(n);read(m);
    for (R i=1;i<=m;++i)
    {
        read(u[i]);read(v[i]);read1(c[i]);read1(d[i]);
        add(u[i],v[i],c[i],i);
    }
    dij1(1);ans1=dist1[n];build(n);
    dij1(n);ans2=dist1[1];build(1);ans=ans1+ans2;
    dij2(1,1);dij2(n,2);
    for (R i=1;i<=n;++i) head[i]=0;cnt=0;
    for (R i=1;i<=m;++i) add(v[i],u[i],c[i],i);
    dij2(1,3);dij2(n,4);
    for (R i=1;i<=m;++i)
    {
        if (flag[i]) 
        {
            for (R j=1;j<=n;++j) head[j]=0;cnt=0;
            for (R j=1;j<=m;++j)
                if (j==i) add(v[j],u[j],c[j],j);else add(u[j],v[j],c[j],j);
            dij1(1);ll tot1=dist1[n];dij1(n);ll tot2=dist1[1];ans=min(ans,tot1+d[i]+tot2);
        }
        else
        {
            ans=min(ans,ans1+dist2[v[i]][2]+c[i]+d[i]+dist2[u[i]][3]);
            ans=min(ans,ans2+dist2[v[i]][1]+c[i]+d[i]+dist2[u[i]][4]);
            ans=min(ans,dist2[v[i]][1]+c[i]+d[i]+dist2[u[i]][4]+dist2[v[i]][2]+c[i]+dist2[u[i]][3]);
        }
    }
    if (ans>=inf) printf("-1\n");else printf("%lld\n",ans);
    return 0;
}