P7500 题解

· · 题解

本题的思路非常巧妙,是一道不可多得的带思维的广搜好题,下面我将循序渐进地教大家如何解决本题(特别地,在本篇题解中,所有变量均会统一用小写字母表示)。

一开始,我粗略地看完题目之后,脑海中冒出来的第一个想法是跑 p 遍单源最短路,由于路径权值只有 01,因此可以简化成一个关于 bfs 的问题,每次扩展的范围就是所有经过当前车站的线路停靠的所有站点,对于每个询问,如果车站 s 能扩展到车站 t,就让 ans 加上对应的贡献值, 最后输出 ans。我花了 20 分钟时间将我这个想法成功地实现了出来:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int x=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0' && c<='9'){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return x;
}
const int N=100002,M=1002;
vector<int> line[M];
vector<int> station[N];
int f[N];
queue<int> q;
void bfs(){
    while(!q.empty()){
        int x=q.front();
        q.pop();
        for(auto l:station[x]){
            for(auto s:line[l]){
                if(f[s]==-1){
                    f[s]=f[x]+1;
                    q.push(s);
                }
            }
        }
    }
    return;
}
int main(){
    int n=read(),m=read(),p=read(),k,x,s,t,ans=0;
    for(int i=1;i<=m;i++){
        k=read();
        while(k--){
            x=read();
            line[i].push_back(x);
            station[x].push_back(i);
        }
    }
    while(p--){
        s=read(),t=read();
        memset(f,-1,sizeof(f));
        q.push(s);
        f[s]=0;
        bfs();
        if(f[t]!=-1) ans+=f[t];
    }
    printf("%d",ans);
    return 0;
}

Unfortunately,T 飞了。

仔细地分析了一下这个代码的时间复杂度,发现最坏情况下竟然O(pn^2) 的。很显然,这个时间复杂度级别的代码无论怎么优化都不可能卡过去。

此时,我们就需要改变解决本题的思路。我们在数据范围内发现到一个很奇怪的东西,每个车站停靠的线路数量不超过 50,然后再结合着 m\leq 1000 一起看,我们是否能将主要的时间复杂度转移到 m 上去呢?继续深入地思考一下,我们发现刚才的那个 bfs 貌似是按照地铁线路,一个一个车站地去扩展的,诶?那我们为什么不把车站放一边,直接按照地铁线路去扩展呢?

至此,正解大致的思路已经出来了,首先我们用一个邻接矩阵存储两条线路之间是否有共同停靠的站点,接着把邻接矩阵用邻接表存储节约之后 bfs 的时间,然后预处理 dis 数组(dis_{i,j} 表示从第 i 条线路到第 j 条线路至少需要换乘几次,注意,需要初始化整个数组为一个小于 0 的值,并且要把 dis_{i,i} 的值初始化为 0),这个过程可以通过 n 遍 bfs 来实现,最后,对于每个询问,枚举第一条线路和最后一条线路(即枚举停靠车站 s 和车站 t 的线路),输出贡献值的总和即可,设 asiz_i 中的最大值,则时间复杂度最坏情况下为 O(m^2+a^2n),可以通过本题(link)!参考代码如下:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int x=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0' && c<='9'){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return x;
}
const int N=100002,M=1002,inf=1e9;
vector<int> station[N];
bool con[M][M];
vector<int> c[M];
int dis[M][M];
queue<int> q;
void bfs(int st){
    dis[st][st]=0;
    q.push(st);
    while(!q.empty()){
        int x=q.front();
        q.pop();
        for(auto y:c[x]){
            if(dis[st][y]==-1){
                dis[st][y]=dis[st][x]+1;
                q.push(y);
            }
        }
    }
    return;
}
int main(){
    int n=read(),m=read(),p=read(),k,x,s,t,ans=0;
    for(int i=1;i<=m;i++){
        k=read();
        while(k--){
            x=read();
            station[x].push_back(i);
        }
    }
    for(int i=1;i<=n;i++){
        for(auto u:station[i]){
            for(auto v:station[i]){
                con[u][v]=con[v][u]=1;
            }
        }
    }
    for(int i=1;i<=m;i++){
        for(int j=1;j<=m;j++){
            if(con[i][j]){
                c[i].push_back(j);
            }
        }
    }
    memset(dis,-1,sizeof(dis));
    for(int i=1;i<=m;i++) bfs(i);
    while(p--){
        s=read(),t=read();
        int mins=inf;
        for(auto u:station[s]){
            for(auto v:station[t]){
                if(dis[u][v]!=-1){
                    mins=min(mins,dis[u][v]+1);
                }
            }
        }
        if(mins!=inf) ans+=mins;
    }
    printf("%d",ans);
    return 0;
}