如何优化SPFA

· · 算法·理论

前置芝士:SPFA是 Bellman-Ford 算法的优化。but,它的最坏时间复杂度为 O(nm),且总被某些不愿意透露姓名的邪恶善良出题入卡没(网格图、菊花图等数据)。

众所周知 SPFA 优化主要是双端队列实现的STL、LLL优化,使其接近 Dijkstra 的优先队列。这样做完全丢失了 SPFA 的精华 ,真是去其糟粕,取其精华,也让大家对 SPFA 彻底失望。大家应该都听说过“关于 SPFA ,它已经si了”,这句话像一把刀,插在所有喜爱 SPFA 的人士中。

本文提出了多种优化,使其不被某些不愿意透露姓名的善良出题人杀死。

1 拓扑排序优化

我们可以发现,对于某个节点,如果一个前驱松弛该节点,那它在传统 SPFA 中就会入队列,持续松弛后续节点,那如果此时又有另一个前驱节点松弛该节点,那我们前面的松弛操作就要再来一次了。

但是,如果我们尽量让该节点的所有前驱节点(或大部分前驱节点)松弛完该节点后,我们再让它入队。这个优化有点类似于拓扑排序,但是原图毕竟不是严格的 DAG,怎么转化又成了问题。我们通过 Tarjan(不了解的可以去了解一下)来消除返祖边(注意这里不消除横叉边带来的前驱节点,因为依然能保证图为 DAG)带来的前驱节点的贡献。这里有个坑点,就是前驱节点一定是起点能到的。

这个优化在近似有向无环图(DAG)的图中表现良好,但是在一些稠密图就不行了,我们需要下一个优化来解决这个问题。

tarjan 部分,相比于模板 tarjan 有所改动

void tarjan(int x){
    g[x].dfn=g[x].low=++l;
    stk[++top]=x;
    k[x]=w[x]=1;//k[]:记录当前递归了哪些节点
    for (int i=h[x];i;i=a[i].t){v=a[i].x;//v为全局变量,可以看最下面的完整代码(屎山代码不要介意)
        if (a[i].qq==0) continue;//这条边不属于我们构建出的残余网络,直接跳过,下个优化会将怎么构建残余网络
        if (g[v].dfn&&k[v]) --p[v];//g[v].dfn如果不等于0,代表这条边为返祖边或横叉边,k[]如果等于0,则代表这条边只是横叉边,不是返祖边。如果这条边是返祖边,那就消除它对v的贡献,其中p[]就是我们拓扑排序中常用的存前驱节点数量的数组
        if (!g[v].dfn) tarjan(v);
        if (w[v]) g[x].low=min(g[x].low,g[v].low);
    }if (g[x].dfn==g[x].low){
        int y=0;
        while (x!=y) y=stk[top--],w[y]=0;
    }k[x]=0;
}

spfa 主函数:

struct spf{
    int x,f;//f:节点是否第一次入队
};
inline bool spfa(int s){
    queue <spf> sta;
    init(s);sta.push({s,1});
    while (!sta.empty()){
        spf tp=sta.front();u=tp.x;
        sta.pop();w[u]=0;
        if (c[u]>n) return 0;
        for (int i=h[u];i;i=a[i].t){
            l=0,v=a[i].x;
            if (ans[v]>ans[u]+a[i].u)
                c[v]=c[u]+1,ans[v]=ans[u]+a[i].u,fau[v]=u,l=1;
            if (!w[v]&&(p[v]==1&&tp.f||l&&p[v]==0)) w[v]=1,sta.push({v,p[v]}); 
            if (tp.f&&p[a[i].x]>0) --p[a[i].x];//tp.f:只有第一次入队的节点才算入p[]
        }
    }return 1;
}

2 Prim-like 预处理

在找最短路径时,如果一开始就有个大致正确的方向,后续的精确计算就会快很多。

我们的预处理借鉴了最小生成树算法 Prim 的思想:从源点开始,总是先走看起来最短的边

这样预处理有两个目的,

拓扑排序优化的劣势在于,遇到两个点的环,即无向边(只是可能两边边权不一样),我们必须选择一个节点作另一个节点的前驱节点,在裸拓扑排序优化中,遇到大量“无向边”的优化效率对建边的顺序极其依赖。但是如果借助当前优化,可以建出质量更好的 DAG,提升优化效率。

我们虽然是按 Prim 的思想对优先队列进行构建,再预处理,但是依然要保证松弛操作的成功,即仍要判断 ans[v] > ans[u] + a[i].u 松弛失败则不加入优先队列。还有不要把这个优化和 Prim 搞混了,Prim 处理最小生成树问题时是无向图,而我们加入这个优化时并不局限于无向图。

两个优化实现:

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define add(x,y,z) (a[++cnt]={y,h[x],0,z},h[x]=cnt)
const int N=8e5+8,M=2e6+8;
const ll C=0x3f3f3f3f3f3f3f3f;
struct G{int x,t,qq;ll u;} a[M];
int n,m,h[N],cnt=0;
struct Tar{int dfn,low;} g[N];
ll ans[N],z;
int p[N],c[N],l,stk[N],top=0,w[N],k[N],u,v,fau[N];
void tarjan(int x){
    g[x].dfn=g[x].low=++l;
    stk[++top]=x;k[x]=w[x]=1;
    for (int i=h[x];i;i=a[i].t){v=a[i].x;
        if (a[i].qq==0) continue;
        if (g[v].dfn&&k[v]) --p[v];
        if (!g[v].dfn) tarjan(v);
        if (w[v]) g[x].low=min(g[x].low,g[v].low);
    }if (g[x].dfn==g[x].low){
        int y=0;
        while (x!=y) y=stk[top--],w[y]=0;
    }k[x]=0;
}
inline void init(int s){
    memset(g,0,sizeof(g));
    memset(p,0,sizeof(p));
    memset(w,0,sizeof(w));
    memset(c,0,sizeof(c));
    memset(fau,0,sizeof(fau));
    memset(ans,C,sizeof(ans));
    for (int i=1;i<=cnt;++i) a[i].qq=0;
    queue <int> q;
    ans[s]=0;
    q.push(s);l=0;c[s]=1;
    while (!q.empty()){
        u=q.front();q.pop();
        if (w[u]) continue;
        ++l;w[u]=1;if (l==n) break;
        for (int i=h[u];i;i=a[i].t){v=a[i].x;
            if (v==fau[u]) continue;
            ++p[v];a[i].qq=1;
            if (ans[v]>ans[u]+a[i].u){
                ans[v]=ans[u]+a[i].u,fau[v]=u,c[v]=c[u]+1;
                if (w[v]==0) q.push(v);
            }
        }
    }
    memset(w,0,sizeof(w));
    l=0,tarjan(s);
    memset(w,0,sizeof(w));
    w[s]=1;
}
struct spf{
    int x,f;
};
inline bool spfa(int s){
    queue <spf> sta;
    init(s);sta.push({s,1});
    while (!sta.empty()){
        spf tp=sta.front();u=tp.x;
        sta.pop();w[u]=0;
        if (c[u]>n) return 0;
        for (int i=h[u];i;i=a[i].t){
            l=0,v=a[i].x;
            if (ans[v]>ans[u]+a[i].u)
                c[v]=c[u]+1,ans[v]=ans[u]+a[i].u,fau[v]=u,l=1;
            if (!w[v]&&(p[v]==1&&tp.f||l&&p[v]==0)) w[v]=1,sta.push({v,p[v]}); 
            if (tp.f&&p[a[i].x]>0) --p[a[i].x];
        }
    }return 1;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    int s;
    cin >>n>>m>>s;
    for (int i=1,x,y;i<=m;++i){
        cin >>x>>y>>z;
        if (x==y) continue;
        add(x,y,z);
    }
    if (!spfa(s)) return cout <<"No answer.",0;
    for (int i=1;i<=n;++i) cout <<ans[i]<<' ';
    return 0;
}

3 自适应优先级调度(加速路径传播)

有了前面两个优化,我们已经可以 AC 大多数卡 SPFA 的图了,但这还不够。

传统 SPFA 按严格广搜序处理节点,但有些节点可能比其他节点更重要。就例如如果正确的路径比一个看似正确的近似最短路经长一点,那也就意味着,正确的答案总是比次短路慢一拍,就像你和父亲的年龄差是不变的。我们的优化能让算法能够 智能判断哪些节点应该优先处理,消除这种速度差异。

检测的方法也十分粗暴,直接给每个节点一个权值(初始为 0),将 SPFA 主循环的队列改为优先队列,当该节点被松弛且优先队列里不存在该节点时,让这个权值 +1。当然,考虑到 SPFA 是 Bellman-Ford 算法的队列优化,所以当两个节点权值一样时,让最短路径长度较小的节点先进行松弛操作。

实现(只包含此优化的 SPFA):

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define add(x,y,z) (a[++cnt]={y,h[x],z},h[x]=cnt)
const int N=1e6+8,M=5e6+8;
const ll C=0x3f3f3f3f3f3f3f3f;
struct G{int x,t;ll u;} a[M];
int n,m,h[N],cnt=0;
ll ans[N],z;
int c[N],w[N],t[N],u,v;
struct spf{
    int x,t,dep;
    inline bool operator<(const spf&other)const{
        if (t!=other.t) return t<other.t;
        return dep>other.dep;
    }
};
template <typename T>
class Priority_Queue:public priority_queue<T>{
    public:void reserve(size_t n){this->c.reserve(n);}
};
inline bool spfa(int s){
    memset(ans,C,sizeof(ans));
    memset(c,0,sizeof(c));
    memset(t,0,sizeof(t));
    Priority_Queue <spf> sta;
    sta.reserve(n);
    sta.push({s,1,1});
    w[s]=t[s]=c[s]=1;ans[s]=0;
    while (!sta.empty()){
        spf tp=sta.top();u=tp.x;
        sta.pop();w[u]=0;
        if (c[u]>n) return 0;
        for (int i=h[u];i;i=a[i].t){v=a[i].x;
            if (ans[v]>ans[u]+a[i].u){
                ans[v]=ans[u]+a[i].u;c[v]=c[u]+1;
                if (w[v]==0) w[v]=1,sta.push({v,++t[v],c[v]});
            }
        }
    }return 1;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    int s;
    cin >>n>>m>>s;
    for (int i=1,x,y,z;i<=m;++i){
        cin >>x>>y>>z;
        if (x==y) continue;
        add(x,y,z);
    }
    if (!spfa(s)) return cout <<"No answer.",0;
    for (int i=1;i<=n;++i) cout <<ans[i]<<' ';
    return 0;
}

4 周期性负环检测

传统 SPFA 要绕很多圈才能发现负环,浪费大量时间。即使有了前 3 个优化,遇到负环还是要寄,此时你需要定期判断有没有负环(其实严格来说,这属于 SPFA 判负环的优化)。

我们的优化采用定期检查的方式:每处理一定数量的边后,快速检查一下是否陷入了循环

在计算过程中,我们记录每个节点的 "最佳前驱"(即当前找到的最短路径是从哪个邻居来的)。如果这些前驱关系形成了一个圈,那么 只有一种可能,有负环。

实现:

inline bool check(){
    if (ks<m) return 0;//ks为操作记数,均摊操作
    ks=0;
    for (int i=1;i<=n;++i) dsu[i]=i;
    for (int i=1,dx,dy;i<=n;++i){
        dx=find(fau[i]),dy=find(i);
        if (dx==dy) return 1;//有负环
        dsu[dy]=dx;
    }
    return 0;
}

5 见证时刻的奇迹

我们在多种经典卡 SPFA 的图上进行了测试,包括:

【模板】单源最短路径(标准版) AC 记录

【模板】全源最短路(Johnson) AC 记录

[图论与代数结构 202] 最短路问题_2 AC 记录

自制的实验数据

优化后的 SPFA 在所有测试案例上均表现稳定,时间复杂度趋向于 O((n+m) \log n),且未出现性能退化现象。

6 code

屎山代码:

//拓扑排序、最小生成树思想与加速路径传播优化spfa
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define getchar_fread() (p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXN,stdin),p1==p2)?EOF:*p1++)
#define flush() (fwrite(buffer,1,q1+1,stdout),q1=-1)
#define add(x,y,z) (a[++cnt]={y,h[x],0,z},h[x]=cnt)
const int MAXN=1<<16,N=8e5+8,M=2e6+8;
const ll C=0x3f3f3f3f3f3f3f3f;
char *p1,*p2,buf[MAXN],buffer[MAXN];
int q1=-1,q2=MAXN-1,j=0;
inline void putchar_fwrite(char x){if (q1==q2) flush();buffer[++q1]=x;}
template<typename T>
inline void read(T &x){x=0;
    T c=getchar_fread(),f=0;
    while(!isdigit(c)){if(c=='-') f=1;c=(T)getchar_fread();}
    while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=(T)getchar_fread();
    if (f) x=~x+1;
}
template<typename T>
inline void write(T x){
    if (x<0) putchar_fwrite('-'),x=~x+1;
    T print[65];
    while (x) print[++j]=x%10,x/=10;
    if (j==0) print[++j]=0;
    while (j) putchar_fwrite(print[j--]^48);
}
struct G{int x,t,qq;ll u;} a[M];
int n,m,h[N],cnt=0,ks,dsu[N];
struct Tar{int dfn,low;} g[N];
ll ans[N],z;
int p[N],c[N],l,stk[N],top=0,w[N],k[N],t[N],u,v,fau[N];
void tarjan(int x){
    g[x].dfn=g[x].low=++l;
    stk[++top]=x;k[x]=w[x]=1;
    for (int i=h[x];i;i=a[i].t){v=a[i].x;
        if (a[i].qq==0) continue;
        if (g[v].dfn&&k[v]) --p[v];
        if (!g[v].dfn) tarjan(v);
        if (w[v]) g[x].low=min(g[x].low,g[v].low);
    }if (g[x].dfn==g[x].low){
        int y=0;
        while (x!=y) y=stk[top--],w[y]=0;
    }k[x]=0;
}
struct pri{
    int x;ll u;
    inline bool operator<(const pri&other)const{return u>other.u;}
};
struct spf{
    int x,f,t,dep;
    inline bool operator<(const spf&other)const{
        if (t!=other.t) return t<other.t;
        return dep>other.dep;
    }
};
template <typename T>
class Priority_Queue:public priority_queue<T>{
    public:void reserve(size_t n){this->c.reserve(n);}
};
int find(int x){return dsu[x]==x?x:dsu[x]=find(dsu[x]);}
inline void init(int s){
    memset(g,0,sizeof(g));
    memset(p,0,sizeof(p));
    memset(w,0,sizeof(w));
    memset(t,0,sizeof(t));
    memset(c,0,sizeof(c));
    memset(fau,0,sizeof(fau));
    memset(ans,C,sizeof(ans));
    for (int i=1;i<=cnt;++i) a[i].qq=0;
    Priority_Queue <pri> q;
    q.reserve(m);ans[s]=0;
    q.push({s,0});l=0;c[s]=1;
    while (!q.empty()){
        u=q.top().x;q.pop();
        if (w[u]) continue;
        ++l;w[u]=1;if (l==n) break;
        for (int i=h[u];i;i=a[i].t){v=a[i].x;
            if (v==fau[u]) continue;
            ++p[v];a[i].qq=1;
            if (ans[v]>ans[u]+a[i].u){
                ans[v]=ans[u]+a[i].u,fau[v]=u,c[v]=c[u]+1;
                if (w[v]==0) q.push({v,a[i].u});
            }
        }
    }
    memset(w,0,sizeof(w));
    l=0,tarjan(s);
    memset(w,0,sizeof(w));
    w[s]=t[s]=1;ks=0;
}
inline bool check(){
    if (ks<m) return 0;
    ks=0;
    for (int i=1;i<=n;++i) dsu[i]=i;
    for (int i=1,dx,dy;i<=n;++i){
        dx=find(fau[i]),dy=find(i);
        if (dx==dy) return 1;
        dsu[dy]=dx;
    }
    return 0;
}
inline bool spfa(int s){
    Priority_Queue <spf> sta;
    init(s);sta.reserve(n);
    sta.push({s,1,1,1});
    while (!sta.empty()){
        spf tp=sta.top();u=tp.x;
        sta.pop();w[u]=0;
        if (c[u]>n||check()) return 0;
        for (int i=h[u];i;i=a[i].t){
            l=0,v=a[i].x;++ks;
            if (ans[v]>ans[u]+a[i].u)
                c[v]=c[u]+1,ans[v]=ans[u]+a[i].u,fau[v]=u,l=1;
            if (!w[v]&&(p[v]==1&&tp.f||l&&p[v]==0)) w[v]=1,sta.push({v,p[v],++t[v],c[v]}); 
            if (tp.f&&p[a[i].x]>0) --p[a[i].x];
        }
    }return 1;
}
signed main(){
    int s;
    read(n),read(m),read(s);
    for (int i=1,x,y;i<=m;++i){
        read(x),read(y),read(z);
        if (x==y) continue;
        add(x,y,z);
    }
    if (!spfa(s)) return cout <<"No answer.",0;
    for (int i=1;i<=n;++i) write(ans[i]),putchar_fwrite(' ');
    flush();
    return 0;
}