题解:P12319 [蓝桥杯 2024 国研究生组] 最短路

· · 题解

P12319 [蓝桥杯 2024 国研究生组] 最短路 - 题解

前置知识:分层图最短路,矩阵加速。

在尝试此题前建议先完成 P2886 [USACO07NOV] Cow Relays G。

正文

首先,题中提到了一项机制:对于每次询问,可以把一条边的边权整除 2,但每次询问只有可以进行一次此操作(即使多次经过一条边也只生效一次),所以考虑拆点:每个点有两种状态:进行过整除操作和没进行过,对拆开的两个点分别连边。

假设题中从 uv 连了一条有向边,边权为 w,则对于拆出的点 u' 和点 v',我们需要额外从 uv' 连一条边权为 \frac{w}{2} 的边,代表在此处进行整除操作,当然,也要从 u'v' 连一条边权为 w 的边,代表进行过整除操作操作的情况。

由于题中的 c_i \le 10^9,所以使用矩阵乘法进行加速,其余部分就和 P2886 一样了……吗?

由于这道题需要进行 m 次操作,而 m \le 1000, n \le 50,同时我们还进行了拆点,所以矩阵大小为 100 \times 100,则总的时间复杂度为 O((2 \times n)^3 \times m\log{c_i}),不能通过此题。

注意到,我们可以使用倍增进行预处理,处理出经过 2^0 条边到 2^{30} 条边的矩阵,然后对每次询问的 c_i 进行拆位计算,减少矩阵乘法的次数。但这种优化实际上只减少了常数时间复杂度,并没有改变根本。然而,由于每次询问都给出了起点 s,我们可以把初始矩阵中的第 s 行单独提取出来,然后再与倍增预处理出的矩阵进行矩阵乘法,这样单次计算的时间复杂度由 O((2 \times n)^3) 变成了 O((2 \times n)^2),同时保证了以 s 为起点的值正确,此题解决。

时间复杂度:预处理部分为 O(30 \times (2\times n)^3),查询部分为 O(m \times (2 \times n)^2 \times \log{c_i})

以下为代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;//记得开 long long 
const ll inf=0x3f3f3f3f3f3f3f3f;
struct node
{
    int n,m;
    ll mp[105][105];
    node(){memset(mp,0x3f,sizeof(mp));}//在使用了 node(int x,int y) 这一构造函数后,直接开node数组无法通过编译,所以要加这一行 
    node(int x,int y){n=x,m=y,memset(mp,0x3f,sizeof(mp));}
    node operator * (node x)
    {
        int z=x.m;
        node res(n,z);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=z;j++)
                for(int k=1;k<=m;k++)
                    res.mp[i][j]=min(res.mp[i][j],mp[i][k]+x.mp[k][j]);//广义矩阵乘法 
        return res;
    }
};
node to[35];
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    int n,q;
    cin>>n;
    node a(n*2,n*2);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            int x;
            cin>>x;
            if(x!=0)
            {
                a.mp[i][j]=x;
                a.mp[i][j+n]=x/2;
                a.mp[i+n][j+n]=x;//拆点 
            }
        }
    }
    to[0]=a;
    for(int i=1;i<=30;i++)
        to[i]=to[i-1]*to[i-1];//倍增 
    cin>>q;
    while(q--)
    {
        int s,t,k;
        cin>>s>>t>>k;
        node st(1,n*2);
        for(int i=1;i<=n*2;i++)
            st.mp[1][i]=a.mp[s][i];
        k--;//k要减一,因为 k==1 为初始矩阵的情况,而进行一次乘法后为经过 2 条边的情况 
        for(int i=0;i<=30;i++)
            if(k&(1<<i))
                st=st*to[i];
        ll ans=min(st.mp[1][t],st.mp[1][t+n]);
        cout<<(ans==inf?-1:ans)<<"\n";
    }
    return 0;
}