从 WC2009 的一道题入手,思考网格图最短路问题的简单维护手法

· · 题解

2023/5/25,最近有点忙,有时间改一下博客风格。

2024/7/4,重开!前面的代码懒得改了,看着办吧(悲),我好鸽啊,不过反正也没人理我,就这样吧。

事情是这样的,作者退役已久,连写博客的风格都换了,一天偶然打开一道题,发现所有题解(尤其是洛谷题解)的风格大概是这样的:

看起来十分劝退,但其实目前的我们处理这些问题,只需要 NOIP 级别的知识和思维。

问题引入

一个 n\times m 的网格,设第 ij 列的非负权值为 w_{i,j},一个最短路查询形如:(a,b,c,d),我们要求的即为:从 (a,b)(c,d) 的一条路径,满足路径上相邻两个格子在原网格图上仍然相邻,最小化路径的权值和。

最简单的问题

只有一次最短路查询。

额,这个比较简单,直接跑迪杰斯特拉算法就好了,时间复杂度 O(nm\log),我也没什么更优的做法。

#include<bits/stdc++.h>
#define 上(子,丑,寅) for(int 子=丑;子<=寅;++子) 
using 长整型=long long; 
const int 行总=11e5,列总=50,变行[]={0,1,0,-1},变列[]={1,0,-1,0};
const 长整型 无穷大=1ll<<61; 
using namespace std;
int 行数,列数,权重[行总][列总],起行,起列,终行,终列;
长整型 距离[行总][列总];
#define 三元组 tuple<int,int,int> 
priority_queue<三元组,vector<三元组>,greater<三元组>>小根堆; 
int main()
{
    cin>>行数>>列数;
    上(甲,1,行数)
        上(乙,1,列数)
            cin>>权重[甲][乙],距离[甲][乙]=无穷大;
    cin>>起行>>起列>>终行>>终列;
    小根堆.emplace(距离[起行][起列]=权重[起行][起列],起行,起列);
    while(!小根堆.empty())
    {
        长整型 距=get<0>(小根堆.top());
        int 行=get<1>(小根堆.top()),列=get<2>(小根堆.top());
        小根堆.pop();
        if(距离[行][列]<距)continue;
        上(甲,0,3)
        {
            int 下行=行+变行[甲],下列=列+变列[甲];
            if(1<=下行&&下行<=行数&&1<=下列&&下列<=列数&&距+权重[下行][下列]<距离[下行][下列])
                小根堆.emplace(距离[下行][下列]=距+权重[下行][下列],下行,下列); 
        }
    }
    cout<<距离[终行][终列]; 
    return 0;
}

静态查询

q 次最短路查询。

如果 m 比较小那就有比暴力迪杰斯特拉看起来更好的做法。

如果是这样我们先假定不像某些毒瘤一样强制在线,这样我们就可以用优秀的离线算法,以下只是其中一个粗略的例子。

选定中间一列,假设部分询问的最短路径经过这一列,用迪杰斯特拉 O(nm^2\log) 求得这一列到整个网格图的最短距离。

然后对于每个询问,O(m) 枚举这一列所有可能的中转点更新答案。

那么剩下所有可能的情况,最短路都不会经过中间一列,可以分解成两个不相交的子问题求解,复杂度即为 O(nm^2\log^2+qm\log)

#include<bits/stdc++.h>
#define 上(子,丑,寅) for(int 子=丑;子<=寅;++子)
#define 下(子,丑,寅) for(int 子=丑;子>=寅;--子)
using 长整型=long long; 
const int 行总=11e5,列总=50,问总=11e5,变行[]={0,1,0,-1},变列[]={1,0,-1,0};
const 长整型 无穷大=1ll<<61; 
using namespace std;
int 行数,列数,问数,权重[行总][列总];
长整型 距离[行总][列总];
struct 询问
{
    int 起行,起列,终行,终列,标号;
    长整型 答案; 
    bool operator < (询问 甲) const
    {
        return 标号<甲.标号; 
    }
}问[问总];
#define 三元组 tuple<int,int,int> 
priority_queue<三元组,vector<三元组>,greater<三元组>>小根堆;
void 最短路(int 左端,int 右端,int 起行,int 起列)
{
    上(甲,左端,右端)
        fill(距离[甲]+1,距离[甲]+列总+1,无穷大); 
    小根堆.emplace(距离[起行][起列]=权重[起行][起列],起行,起列);
    while(!小根堆.empty())
    {
        长整型 距=get<0>(小根堆.top());
        int 行=get<1>(小根堆.top()),列=get<2>(小根堆.top());
        小根堆.pop();
        if(距离[行][列]<距)continue;
        上(甲,0,3)
        {
            int 下行=行+变行[甲],下列=列+变列[甲];
            if(左端<=下行&&下行<=右端&&1<=下列&&下列<=列数&&距+权重[下行][下列]<距离[下行][下列])
                小根堆.emplace(距离[下行][下列]=距+权重[下行][下列],下行,下列); 
        }
    }
}
void 分治(int 左端,int 右端,int 前问,int 后问)
{
    if(左端>右端||前问>后问)return;
    int 中列=(左端+右端)/2,分左=前问,分右=后问;
    上(甲,1,列数)
    {
        最短路(左端,右端,中列,甲);
        上(乙,前问,后问)
            问[乙].答案=min(问[乙].答案,距离[问[乙].起行][问[乙].起列]+距离[问[乙].终行][问[乙].终列]-权重[中列][甲]);
    }
    上(甲,前问,后问)
        if(问[甲].起行<中列&&问[甲].终行<中列)
            swap(问[分左++],问[甲]);
    下(甲,后问,分左)
        if(问[甲].起行>中列&&问[甲].终行>中列)
            swap(问[分右--],问[甲]);
    分治(左端,中列-1,前问,分左-1),
    分治(中列+1,右端,分右+1,后问); 
} 
int main()
{
    cin>>行数>>列数;
    上(甲,1,行数)
        上(乙,1,列数)
            cin>>权重[甲][乙];
    cin>>问数;
    上(甲,1,问数)
        cin>>问[甲].起行>>问[甲].起列>>问[甲].终行>>问[甲].终列,
        问[甲].标号=甲,问[甲].答案=无穷大;
    分治(1,行数,1,问数);
    sort(问+1,问+问数+1);
    上(甲,1,问数)
        cout<<问[甲].答案<<'\n'; 
    return 0;
}

修改和查询

q 次操作,包括两种操作:(1,a,b,c),修改位置 (a,b) 的点权为 c(2,a,b,c,d),查询位置 (a,b) 到位置 (c,d) 的最短路径。

如果 m 真的比较小,我们还是有比暴力跑迪杰斯特拉更优的算法的。

我们对 l 行到 r 行建立分治结构,比如线段树,在每个节点上,我们维护三个 m\times m 的方阵 A,B,C,其中:

那么线段树节点也具有的可合并性,对于查询 (a,b,c,d),我们提出 [1,a-1],[a,b],[b+1,n] 三个区间的信息,通过与线段树节点信息合并类似的做法,就可以快速求解。

那么线段树节点如何进行信息合并呢,一种朴素的想法即为,把这 3m 个节点暴力连边,建图,跑弗洛伊德全源最短路,那么我们即可做到 O(nm^3\log+qm^3\log) 的时间复杂度与 O(nm^2) 的空间复杂度,理论复杂度比题解更优(原因见下文),但是常数为 3^3=27,略大,无法通过此题(我算了下好像只能跑 2500 的样子),用于例题可得 80 分。

#include<bits/stdc++.h>
#define up(a,b,c) for(int a=b;a<=c;++a)
using namespace std;
using ll=long long;
const int N=11e4,E=6;
const ll inf=1e13;
int n,m,q,d[N][E];
ll U[3*E][3*E];
struct mat
{
    ll a[E][E];
    void t0()
    {
        up(i,0,m-1)fill(a[i],a[i]+m,inf);
    }
    ll *operator [](const int x) {return a[x];}
    mat()
    {
        t0();
    }
    void t1()
    {
        t0();
        up(i,0,m-1)a[i][i]=0;
    }
}r1;
struct mad
{
    mat A,B,C;
    mad operator = (const int *d)
    {
        int s[E+1];s[0]=0;
        up(i,0,m-1)s[i+1]=s[i]+d[i];
        up(i,0,m-1)up(j,0,m-1)
            A[i][j]=B[i][j]=C[i][j]=i>j?s[i+1]-s[j]:s[j+1]-s[i];
        return *this;
    }
    mad operator + (const mad &b) const
    {
        mad res;
        up(i,0,3*m-1)fill(U[i],U[i]+3*m,inf);
        up(i,0,m-1)up(j,0,m-1)
            U[i][j]=A.a[i][j],
            U[i][j+m]=U[j+m][i]=B.a[i][j],
            U[i+m][j+m]=min(C.a[i][j],b.A.a[i][j]),
            U[i+m][j+m*2]=U[j+m*2][i+m]=b.B.a[i][j],
            U[i+m*2][j+m*2]=b.C.a[i][j];
        up(k,0,m*3-1)up(i,0,m*3-1)up(j,0,m*3-1)
            U[i][j]=min(U[i][j],U[i][k]+U[k][j]);
        up(i,0,m-1)up(j,0,m-1)
            res.A[i][j]=U[i][j],
            res.B[i][j]=U[i][j+m*2],
            res.C[i][j]=U[i+m*2][j+m*2];
        return res;
    }
}nd[N<<2];
void change(int l,int r,int k,int lb,int rb)
{
    if(r<lb||rb<l)return;
    if(l==r)nd[k]=d[l];
    else
    {
        int mid=(l+r)>>1;
        change(l,mid,k<<1,lb,rb),
        change(mid+1,r,k<<1|1,lb,rb);
        nd[k]=nd[k<<1]+nd[k<<1|1];
    }
}
mad qry(int l,int r,int k,int lb,int rb)
{
    if(lb<=l&&r<=rb)return nd[k];
    int mid=(l+r)>>1;
    if(rb<=mid)return qry(l,mid,k<<1,lb,rb);
    if(mid<lb)return qry(mid+1,r,k<<1|1,lb,rb);
    return qry(l,mid,k<<1,lb,rb)+qry(mid+1,r,k<<1|1,lb,rb);
}
ll solve(int l,int x,int r,int y)
{
    if(l>r)swap(l,r),swap(x,y);
    mad Z=qry(1,n,1,l,r);
    mat L,R,M,Lr,lR;
    M=Z.B;
    if(l>1)
    {
        mad X=qry(1,n,1,1,l-1);
        L=X.C,lR=(X+Z).C;
    }
    else L=r1,lR=Z.C;
    if(r<n)
    {
        mad Y=qry(1,n,1,r+1,n);
        R=Y.A,Lr=(Z+Y).A;
    }
    else R=r1,Lr=Z.A;
    up(i,0,m-1)up(j,0,m-1)
        U[i][j]=min(Lr.a[i][j]-d[l][j],L.a[i][j]+d[l][i]),
        U[i][j+m]=M.a[i][j]-d[r][j],
        U[j+m][i]=M.a[i][j]-d[l][i],
        U[i+m][j+m]=min(lR.a[i][j]-d[r][j],R.a[i][j]+d[r][i]);
    up(k,0,m*2-1)up(i,0,m*2-1)up(j,0,m*2-1)
        U[i][j]=min(U[i][j],U[i][k]+U[k][j]);
    return U[x][y+m]+d[r][y];
}
int main()
{
    cin>>n>>m;
    r1.t1();
    up(j,0,m-1)up(i,1,n)cin>>d[i][j];
    change(1,n,1,1,n);
    cin>>q;
    while(q--)
    {
        int op,a,b,c,e;
        cin>>op>>a>>b>>c;
        if(op==1)d[b][a-1]=c,change(1,n,1,b,b);
        else cin>>e,cout<<solve(b,a-1,e,c-1)<<'\n';
    }
    return 0;
}

那么有没有办法使得常数更小呢?这里感谢与 Xopered 的交流,祝愿其信息学之路光芒璀璨。

就是这题

有一个做法只适用于 m\le 6 的情况,但常数较小,介绍一下。

为了卡常,注意到一个事实,答案一定不会爆 32 位有符号整型变量。

考虑当 m=6 的特殊情况,显然从相邻的格子再绕回明显不优,路径重复或交叉明显不优,所以不管点权变成什么样子,答案的最短路径从一侧穿出绕回只会发生一次。 那么,我们用这个来合并信息就明显简单很多了!

合并时:从左边到左边,除了左边的信息外,还要考虑跨过一次的信息(2 次矩阵乘法),右边到右边同理,而从左边到右边要考虑 S 型的情况(3 次矩阵乘法),一共 7 次矩阵乘法。

计算答案时,相对比较复杂,与合并不太一样,要注意以下的几种情况及其组合(其实就是考虑左/右边界在漫游整张图之后回到左/右边界的情况),如果实在讨论不出来,由于这里不是复杂度瓶颈,你搞一个迪杰斯特拉,准能搞出来,比我的写法运行时间还短,不要(像我一样)被这一步卡了,这其实是最简单的一步。 代码如下,并不长:

#include<bits/stdc++.h>
#define up(a,b,c) for(int a=b;a<=c;++a)
using namespace std;
const int N=11e4;
int n,q,d[N][6];
struct mat
{
    int a[6][6];
    void t0()
    {
        fill(a[0],a[0]+36,2e9);
    }
    int *operator [](const int x) {return a[x];}
    mat()
    {
        t0();
    }
    void t1()
    {
        t0();
        up(i,0,5)a[i][i]=0;
    }
    mat rev() const
    {
        mat res=*this;
        up(i,0,5)up(j,0,i-1)swap(res.a[i][j],res.a[j][i]);
        return res;
    }
    mat operator + (const mat &b) const
    {
        mat res;
        up(i,0,5)up(j,0,5)res.a[i][j]=min(a[i][j],b.a[i][j]);
        return res;
    }
    mat operator * (const mat &b) const
    {
        mat res;
        up(i,0,5)up(j,0,5)up(k,0,5)
            res.a[i][j]=min(res.a[i][j],a[i][k]+b.a[k][j]);
        return res;
    }
}r1;
struct mad
{
    mat A,B,C;
    mad operator = (const int *d)
    {
        int s[7];s[0]=0;
        up(i,0,5)s[i+1]=s[i]+d[i];
        up(i,0,5)up(j,0,5)
            A[i][j]=B[i][j]=C[i][j]=i>j?s[i+1]-s[j]:s[j+1]-s[i];
        return *this;
    }
    mad operator + (const mad &b) const
    {
        mad res;
        res.A=A+B*b.A*B.rev(),
        res.B=B*(r1+b.A*C)*b.B,
        res.C=b.C+b.B.rev()*C*b.B;
        return res;
    }
}nd[N<<2];
void change(int l,int r,int k,int lb,int rb)
{
    if(r<lb||rb<l)return;
    if(l==r)nd[k]=d[l];
    else
    {
        int mid=(l+r)>>1;
        change(l,mid,k<<1,lb,rb),
        change(mid+1,r,k<<1|1,lb,rb);
        nd[k]=nd[k<<1]+nd[k<<1|1];
    }
}
mad qry(int l,int r,int k,int lb,int rb)
{
    if(lb<=l&&r<=rb)return nd[k];
    int mid=(l+r)>>1;
    if(rb<=mid)return qry(l,mid,k<<1,lb,rb);
    if(mid<lb)return qry(mid+1,r,k<<1|1,lb,rb);
    return qry(l,mid,k<<1,lb,rb)+qry(mid+1,r,k<<1|1,lb,rb);
}
int solve(int l,int x,int r,int y)
{
    if(l>r)swap(l,r),swap(x,y);
    mad Z=qry(1,n,1,l,r);
    mat L,R,M,Lr,lR;
    M=Z.B;
    if(l>1)
    {
        mad X=qry(1,n,1,1,l-1);
        L=X.C,lR=(X+Z).C;
    }
    else L=r1,lR=Z.C;
    if(r<n)
    {
        mad Y=qry(1,n,1,r+1,n);
        R=Y.A,Lr=(Z+Y).A;
    }
    else R=r1,Lr=Z.A;
    return ((r1+Lr*L)*M*(r1+R*lR))[x][y];
}
int main()
{
    cin>>n;
    r1.t1();
    up(j,0,5)up(i,1,n)cin>>d[i][j];
    change(1,n,1,1,n);
    cin>>q;
    while(q--)
    {
        int op,a,b,c,e;
        cin>>op>>a>>b>>c;
        if(op==1)d[b][a-1]=c,change(1,n,1,b,b);
        else cin>>e,cout<<solve(b,a-1,e,c-1)<<'\n';
    }
    return 0;
}

真的没有更好的做法了吗

事实上常数瓶颈在合并部分,我们重点考虑它(最后计算的部分我懒得想了直接用又快又好迪杰斯特拉了)

你发现合并两个段的部分如果出现反复横跳,直接使用弗洛伊德算法就好了,因此,实际上只需要对上传部分做微不足道的修改就可以得到不需要任何分类讨论的一般性的程序。

欢迎各位来检验该做法的正确性!

#include<bits/stdc++.h>
#define up(a,b,c) for(int a=b;a<=c;++a)
using namespace std;
using ll=long long;
const int N=11e4,E=7;
const ll inf=1e13;
int n,m,q,d[N][E];
ll U[2*E][2*E],dis[2*E+1];
bool ok[2*E+1]; 
struct mat
{
    ll a[E][E];
    void t0()
    {
        up(i,0,m-1)fill(a[i],a[i]+m,inf);
    }
    ll *operator [](const int x) {return a[x];}
    mat()
    {
        t0();
    }
    void t1()
    {
        t0();
        up(i,0,m-1)a[i][i]=0;
    }
    mat rev() const
    {
        mat res=*this;
        up(i,0,m-1)up(j,0,i-1)swap(res.a[i][j],res.a[j][i]);
        return res;
    }
    mat nrm() const
    {
        mat res=*this;
        up(k,0,m-1)up(i,0,m-1)up(j,0,m-1)
            res.a[i][j]=min(res.a[i][j],res.a[i][k]+res.a[k][j]);
        return res;
    }
    mat operator + (const mat &b) const
    {
        mat res;
        up(i,0,m-1)up(j,0,m-1)res.a[i][j]=min(a[i][j],b.a[i][j]);
        return res;
    }
    mat operator * (const mat &b) const
    {
        mat res;
        up(i,0,m-1)up(j,0,m-1)up(k,0,m-1)
            res.a[i][j]=min(res.a[i][j],a[i][k]+b.a[k][j]);
        return res;
    }
}r1;
struct mad
{
    mat A,B,C;
    mad operator = (const int *d)
    {
        int s[m+1];s[0]=0;
        up(i,0,m-1)s[i+1]=s[i]+d[i];
        up(i,0,m-1)up(j,0,m-1)
            A[i][j]=B[i][j]=C[i][j]=i>j?s[i+1]-s[j]:s[j+1]-s[i];
        return *this;
    }
    mad operator + (const mad &b) const
    {
        mad res;
        mat x=(b.A+C).nrm();
        res.A=A+B*x*B.rev(),
        res.B=B*(r1+x)*b.B,
        res.C=b.C+b.B.rev()*x*b.B;
        return res;
    }
}nd[N<<2];
void change(int l,int r,int k,int lb,int rb)
{
    if(r<lb||rb<l)return;
    if(l==r)nd[k]=d[l];
    else
    {
        int mid=(l+r)>>1;
        change(l,mid,k<<1,lb,rb),
        change(mid+1,r,k<<1|1,lb,rb);
        nd[k]=nd[k<<1]+nd[k<<1|1];
    }
}
mad qry(int l,int r,int k,int lb,int rb)
{
    if(lb<=l&&r<=rb)return nd[k];
    int mid=(l+r)>>1;
    if(rb<=mid)return qry(l,mid,k<<1,lb,rb);
    if(mid<lb)return qry(mid+1,r,k<<1|1,lb,rb);
    return qry(l,mid,k<<1,lb,rb)+qry(mid+1,r,k<<1|1,lb,rb);
}
ll solve(int l,int x,int r,int y)
{
    if(l>r)swap(l,r),swap(x,y);
    mad Z=qry(1,n,1,l,r);
    mat L,R,M,Lr,lR;
    M=Z.B;
    if(l>1)
    {
        mad X=qry(1,n,1,1,l-1);
        L=X.C,lR=(X+Z).C;
    }
    else L=r1,lR=Z.C;
    if(r<n)
    {
        mad Y=qry(1,n,1,r+1,n);
        R=Y.A,Lr=(Z+Y).A;
    }
    else R=r1,Lr=Z.A;
    up(i,0,m-1)up(j,0,m-1)
        U[i][j]=min(Lr.a[i][j]-d[l][j],L.a[i][j]+d[l][i]),
        U[i][j+m]=M.a[i][j]-d[r][j],
        U[j+m][i]=M.a[i][j]-d[l][i],
        U[i+m][j+m]=min(lR.a[i][j]-d[r][j],R.a[i][j]+d[r][i]);
    fill(dis,dis+2*m+1,inf),fill(ok,ok+2*m+1,0);
    dis[x]=0;
    up(i,0,m*2-1)
    {
        int u=2*m;
        up(j,0,m*2-1)
            if(dis[j]<dis[u]&&!ok[j])u=j;
        ok[u]=1;
        up(j,0,m*2-1)dis[j]=min(dis[j],dis[u]+U[u][j]);
    }
    return dis[y+m]+d[r][y];
}
int main()
{
    cin>>n>>m;
    r1.t1();
    up(j,0,m-1)up(i,1,n)cin>>d[i][j];
    change(1,n,1,1,n);
    cin>>q;
    while(q--)
    {
        int op,a,b,c,e;
        cin>>op>>a>>b>>c;
        if(op==1)d[b][a-1]=c,change(1,n,1,b,b);
        else cin>>e,cout<<solve(b,a-1,e,c-1)<<'\n';
    }
    return 0;
}