题解:P12250 [科大国创杯初中组 2025] 旅行

· · 题解

很好的一道 bitset 练手题,考试时没有写掉有点遗憾了。

\texttt{Solution}

前置知识:bitset。

看到这个 mex 和按位或,似乎可以发现路径疲劳度的一些限制。\ 是的,不难发现,如果从某一点有一条道路通向另一点,那么路径的疲劳度只会有以下三种情况:

  1. 若该路径经过的第一条边的边权不为 0,这条路径的疲劳度为 0
  2. 若该路径经过的所有边的边权都是 0,或从第一条边的边权开始有一段连续的 0,之后是一条边权不为 1 的边,这条路径的疲劳度为 1
  3. 若该路径从第一条边的边权开始有一段连续的 0,之后是一条边权为 1 的边,这条路径的疲劳度为 2

可以发现,没有其他的情况。

这里有一点要说一下,情况 2 可以简化为某条路径的第一条边的边权为 0,然后情况 3 就会被包含在情况 2 里,这意味着在真正开始统计情况 3 之前,每一条满足情况 3 的路径已经被统计了一遍,那么统计情况 3 时就只需要对每条路径加 1 而不是 2

看看数据约定,发现有 u_i < v_i,显然,这是一个 DAG。

考虑使用 bitset 解决问题。

我们可以在原图的基础上先存一个反边,然后用拓扑排序求出每个点可以到达哪些点,这一步骤使用 bitset 将很好实现。这样,我们就可以统计出有多少条无法到达的路径,这样的路径的疲劳度是 -1

需要注意的是,只有这个部分需要用拓扑排序。

void solve() {
    for (int i = 1; i <= n; i ++ ) in4[i] = in[i];

    for (int i = 1; i <= n; i ++) {
        if(!in[i]) sta.push(i);

        bs[i].set(i, 1); // 每个点都可以到达自己本身
    }

    while (sta.size()) { // 拓扑排序
        auto u = sta.front();
        sta.pop();

        for (int i = head2[u]; i; i = nxt2[i]) {
            int j = to2[i]; // to2,head2,nxt2 均为反边。

            bs[j] |= bs[u]; // 如果点 u 可以到达某个点,那么与 u 有连边的 j 也可以到达这个点。

            if(-- in[j] == 0) sta.push(j);
        }
    }

    for (int i = 1; i <= n; i ++) {
        ans -= (n - bs[i].count()); // 一个点应该和所有的点(包括自己)有路径,所以 n - bs[i].count() 是这个点不能到达的点数。
    }
}

这样我们就完成了对疲劳度为 -1 的路径的统计。

接下来就是对疲劳度为 1 的路径的统计。

这个很好理解。对于一个点 u,如果有一条通向点 v 的边且边权为 0,那么 uv 所有可以到达的点的路径疲劳度都至少1,因为这里可能包含着情况 3

遍历所有的边,对于所有边权为 0 的边进行统计即可。

void solve2() {
    for (int i = 1; i <= n; i ++) {
        now.reset(); // 这里使用 bitset 的原因是为了防止某两点间有多条满足要求的路径而导致重复计算了疲劳速度。

        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 0) {
                now |= bs[j];
            }
        }

        ans += now.count(); // 先统计满足路径中第一条边的边权为 0 的点的数量,这样的路径疲劳度至少为 1。
    }
}

最后是对疲劳度为 2 的路径的统计,即情况 3

我们需要两步动作。

首先,统计出每个点通过一条满足第一条边的边权为 1 的路径可以到达的点的数量。\ 然后,统计有每个点只通过边权为 0 的边可以到达的点,假定有一点 u,可以通过一条边权全部为 0 的路径到达点 v,那么 u 与所有满足 v 通过一条满足第一条边的边权为 1 的路径可以到达的点之间的路径疲劳度为 2

有点绕,原谅我口才不太好,直接看看代码吧。

void solve3() {
    for (int i = 1; i <= n; i ++) {
        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 1) {
                bs1[i] |= bs[j]; // 这里统计了每一个点通过一条满足第一条边的边权为 $1$ 的路径可以到达哪些点。
            }
        }
    }
}

void solve4() {
    for (int i = n; i >= 1; i --) { // 这里有一个类似拓扑排序的顺序,所以要从大到小遍历。
        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 0) { // i 与 j 之间的边权为 0
                bs0[i] |= bs0[j]; // 如果 j 与某个点间的路径疲劳度为 2,那么 i 与这个点间的疲劳度也是 2,这里其实统计了 i 到 j 这条边的边权不为整条路经中第一段连续 0 的最后一个 0 的情况。
                bs0[i] |= bs1[j]; // 统计了 i 到 j 这条边的边权为整条路经中第一段连续 0 的最后一个 0 的情况。
            }
        }

        ans += bs0[i].count(); // 因为之前统计疲劳度为 1 的路径时已经把这些路径计算了一遍,所以这里是加 1 而不是加 2。
    }
}

最后输出即可。

\texttt{Code}

这里放一下完整的代码。

#include <bits/stdc++.h>

using namespace std;

#define int long long

const int N = 2e5 + 10, M = 3e4 + 10;

bitset <M> bs[M], bs1[M], now, bs0[N];

int n, m;
int in[N];
int idx, to[N], nxt[N], w[N], head[N];
int idx2, to2[N], nxt2[N], w2[N], head2[N];
int ans;

typedef pair <int, int> pii;

vector <pii> G[M], G2[M];

void add(int a, int b, int C) {
    idx ++;

    to[idx] = b;
    nxt[idx] = head[a];
    head[a] = idx;
    w[idx] = C;

    G[a].push_back({b, C});
}

void add2(int a, int b, int C) { // 建反边
    idx2 ++;

    to2[idx2] = b;
    nxt2[idx2] = head2[a];
    head2[a] = idx2;
    w2[idx2] = C;

    G2[a].push_back({b, C});
}

queue <int> sta;

void solve() {  
    for (int i = 1; i <= n; i ++) {
        if(!in[i]) sta.push(i);

        bs[i].set(i, 1);
    }

    while (sta.size()) {
        auto u = sta.front();
        sta.pop();

        for (int i = head2[u]; i; i = nxt2[i]) {
            int j = to2[i];

            bs[j] |= bs[u];

            if(-- in[j] == 0) sta.push(j);
        }
    }

    for (int i = 1; i <= n; i ++) {
        ans -= (n - bs[i].count());
    }
}

void solve2() {
    for (int i = 1; i <= n; i ++) {
        now.reset();

        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 0) {
                now |= bs[j];
            }
        }

        ans += now.count();
    }
}

void solve3() {
    for (int i = 1; i <= n; i ++) {
        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 1) {
                bs1[i] |= bs[j];
            }
        }
    }
} 

void solve4() {
    for (int i = n; i >= 1; i --) {
        for (auto u : G[i]) {
            int j = u.first, w = u.second;

            if(w == 0) {
                bs0[i] |= bs0[j];
                bs0[i] |= bs1[j];
            }
        }

        ans += bs0[i].count();
    }
}

signed main(void) {
    cin >> n >> m;

    for (int i = 1; i <= m; i ++) {
        int u, v, c;
        cin >> u >> v >> c;

        add(u, v, c);
        add2(v, u, c); 

        in[u] ++;
    }

    solve();
    solve2();
    solve3(); 
    solve4();   

    cout << ans << endl; 

    return 0;
}