题解:P10947 Sightseeing

· · 题解

题解:P10947 Sightseeing

题意

一句话概括:有向图最短路及其长度加一次短路计数。

思路

次短路模版:P2865。

最短路计数模版:P1144。

没有做过以上两题的可以先做一下。下面用我做这道题的完整过程分步骤讲解:

设起点为 u,终点为 v

1. 最短路求解:设 dis_{0, t} 为从 ut 的最短路,可以用堆优化 dijkstra 算法实现。答案为 dis_{0, v}

2. 最短路计数:设 cnt_{0, t} 为从 ut 的最短路计数,在 dijkstra 算法中:如果最短路会更新为更短的,就覆盖掉原有的 cnt,如果新的路径恰好等于最短路,就给原有的 cnt 加上新路径的计数。答案为 cnt_{0, v}

3. 次短路求解:设 dis_{1, t} 为从 ut 的次短路,这里需要在 dijkstra 算法中加上一些内容:如果最短路 w_0 需要更新,那么先令 dis_{1, now}= w_0,再更新最短路;如果最短路不需要更新,设新路径为 w_1,若 dis_{1, now} > w_1,更新次短路。答案为 dis_{1, v}

4. 次短路计数:设 cnt_{1, t} 为从 ut 的次短路计数,这里需要再用最短路计数的搞法再在 dijkstra 算法中加上类似的内容,不再重复。答案为 cnt_{1,v}

5. 最终答案:如果 dis_{1, v} = dis_{0, v} + 1,那么答案为 cnt_{0, v} + cnt_{1, v},否则答案为 cnt_{0, v}

6. 实测:我用这样的思路做完,发现没有通过!(如果你在这里已经过了那请跳过这部分)此时的核心代码长这样:

while (!q.empty()) {
        qq tt = q.top();
        q.pop();
        int y = tt.pos, d = tt.d1;
        if (vis[y]) {
            continue;
        }
        vis[y] = 1;
        for (int i = h[y]; i; i = e[i].nxt) {
            int j = e[i].to, tw = e[i].w;
            if (dis[0][j] > d + tw) {
                dis[1][j] = dis[0][j];
                cnt[1][j] = cnt[0][j];
                dis[0][j] = d + tw;
                cnt[0][j] = cnt[0][y];
                q.push((qq) {
                    dis[1][j], j
                });
                q.push((qq) {
                    dis[0][j], j
                });
            } else if (dis[0][j] == d + tw) {
                cnt[0][j] += cnt[0][y];
            } else if (dis[1][j] > d + tw) {
                dis[1][j] = d + tw;
                cnt[1][j] = cnt[0][y];
                q.push((qq) {
                    dis[1][j], j
                });
            } else if (dis[1][j] == d + tw) {
                cnt[1][j] += cnt[0][y];
            }
        }
    }

仔细看了一会,会发现一个问题:为什么更新更新 cnt 用的都是其他的 cnt_0?看起来这就不合理,但是又应该怎么更新?

7. 解决问题:既然不知道是用哪个计数器更新,不妨在堆的 node 中加一个变量 op \in \left \{ 1, 2 \right \} 记录一下这个点是用来更新过最短路还是次短路,与此同时,vis 数组就也要加一维,判断的是同样的内容。每次取出堆顶,如果 vis_{op, now} = 1,就跳过,更新时把刚才代码里所有的 cnt_{0, y} 换成 cnt_{op, y},这样就通过了本题。

此时的核心代码:

while (!q.empty()) {
        qq tt = q.top();
        q.pop();
        int y = tt.pos, x = tt.op, d = tt.d1;
        if (vis[x][y]) {
            continue;
        }
        vis[x][y] = 1;
        for (int i = h[y]; i; i = e[i].nxt) {
            int j = e[i].to, tw = e[i].w;
            if (dis[0][j] > d + tw) {
                dis[1][j] = dis[0][j];
                cnt[1][j] = cnt[0][j];
                dis[0][j] = d + tw;
                cnt[0][j] = cnt[x][y];
                q.push((qq) {
                    dis[1][j], 1, j
                });
                q.push((qq) {
                    dis[0][j], 0, j
                });
            } else if (dis[0][j] == d + tw) {
                cnt[0][j] += cnt[x][y];
            } else if (dis[1][j] > d + tw) {
                dis[1][j] = d + tw;
                cnt[1][j] = cnt[x][y];
                q.push((qq) {
                    dis[1][j], 1, j
                });
            } else if (dis[1][j] == d + tw) {
                cnt[1][j] += cnt[x][y];
            }
        }
    }

细节问题

1. 本题多测,记得清空所有数组和变量。

2. 最短路计数需要初始化 cnt_{0, u} = 1,而次短路计数不需要。

3. 注意 dijkstra 算法中判断、处理的先后顺序,避免出现要用的数据先被覆盖了的情况。

code

#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 5;

int T, n, m, s, t, h[N], tot, dis[2][N], cnt[2][N], vis[2][N];
struct node {
    int to, nxt, w;
} e[N];

struct qq {
    int d1, op, pos;
    bool operator <(const qq &k)const {
        return k.d1 < d1;
    }
};

void add(int u, int v, int w) {
    e[++tot].to = v;
    e[tot].w = w;
    e[tot].nxt = h[u];
    h[u] = tot;
}

int dij(int u, int v) {
    memset(dis, 0x3f, sizeof(dis));
    memset(vis, 0, sizeof(vis));
    memset(cnt, 0, sizeof(cnt));
    cnt[0][u] = 1;
    dis[0][u] = 0;
    priority_queue<qq> q;
    q.push((qq) {
        0, 0, u
    });
    while (!q.empty()) {
        qq tt = q.top();
        q.pop();
        int y = tt.pos, x = tt.op, d = tt.d1;
        if (vis[x][y]) {
            continue;
        }
        vis[x][y] = 1;
        for (int i = h[y]; i; i = e[i].nxt) {
            int j = e[i].to, tw = e[i].w;
            if (dis[0][j] > d + tw) {
                dis[1][j] = dis[0][j];
                cnt[1][j] = cnt[0][j];
                dis[0][j] = d + tw;
                cnt[0][j] = cnt[x][y];
                q.push((qq) {
                    dis[1][j], 1, j
                });
                q.push((qq) {
                    dis[0][j], 0, j
                });
            } else if (dis[0][j] == d + tw) {
                cnt[0][j] += cnt[x][y];
            } else if (dis[1][j] > d + tw) {
                dis[1][j] = d + tw;
                cnt[1][j] = cnt[x][y];
                q.push((qq) {
                    dis[1][j], 1, j
                });
            } else if (dis[1][j] == d + tw) {
                cnt[1][j] += cnt[x][y];
            }
        }
    }
    if (dis[0][v] + 1 == dis[1][v]) {
        return cnt[0][v] + cnt[1][v];
    } else {
        return cnt[0][v];
    }
}

int main() {
    cin >> T;
    while (T--) {
        cin >> n >> m;
        memset(h, 0, sizeof(h));
        tot = 0;
        while (m--) {
            int u, v, w;
            cin >> u >> v >> w;
            add(u, v, w);
        }
        cin >> s >> t;
        cout << dij(s, t) << endl;
    }
    return 0;
}