题解:AT_abc450_f [ABC450F] Strongly Connected 2

· · 题解

ABC450F 题解

什么?图论题?好难啊!

现将问题进行转化。当且仅当能够从 1 走到 n,这个图是强连通的。

:::success[证明] 如果从 1 走不到 n,显然是不行的。

如果从 1 能走到 n,对于 u<v 两个点:

vu,沿着 v\to v-1\to v-2\to\dots\to u 走即可;

uv,先从 u 走到 1,再从 1 走到 n,最后从 n 走到 v。 :::

于是这个题就一点都不图论了。记 l 为起点、r 为终点,则它合法就当且仅当存在一组边 e_1,e_2,\dots,e_k 使得 l_{e_1}=1,r_{e_k}=n,l_{e_i}\leq l_{e_{i+1}}\leq r_{e_i}\leq r_{e_{i+1}}。其实就类似于线段覆盖问题。

方法一:直接 DP + 线段树优化

现在我们要按照一定顺序考虑这些边,使得每一条这种通路都能按照顺序被考虑每条边。显然按照起点排序即可。

:::warning[如何排序] 其实你只需要保证 l_{u}\leq l_{v}\leq r_{u}\leq r_{v}u 排在 v 前就行。

除了按照起点排序、终点排序,你甚至可以按照 k_1l+k_2r 排序,其中 k_1,k_2\geq0。 :::

所以说,在考虑过前 i 条边后,我们只关心从 1 最远能走到谁。就算后面有不相邻的边也不影响,因为再选择能够从 1 走过来的边之后,能到的最往后的点一定是新选的这个终点(否则与排序条件矛盾)。

于是考虑设计状态 f_{i,j}:考虑了前 i 条边,从 1 最远能够到达 j,总有有多少种方案。初始:f_{0,1}=1;答案:f_{m,n}。仔细刻画转移(i\to i+1)如下:

对于 j<l_i,这个边选不选都不影响,因为走不过来,等走过来了就不是最远的了,所以 f_{i,j}=2f_{i-1,j}

对于 j>r_i,这个边选不选都不影响,因为它这个终点就不是最远能到的,因此 f_{i,j}=2f_{i-1,j}

对于 l_i\leq j<r_i,这个边不能选,否则会到达 r_i,因此 f_{i,j}=f_{i-1,j}

对于 j=r_i,若这个边选了则上一步只需 j'\in[l_i,r_i],否则上一步必须 j'=r_i,即 f_{i,j}=f_{i-1,j}+\sum_{j'=l_i}^{r_i}f_{i-1,j'}

发现实际上由 f_{i-1}f_i 的变化只是两个区间内整体乘以 2,一个点的单点赋值,直接用一棵线段树整体维护每个时刻的 f_i 即可。时间复杂度 \operatorname{O}(n\log n)

:::success[code]

#include <bits/stdc++.h>

using namespace std;

constexpr int N = 2e5;
constexpr int M = 2e5;
constexpr int P = 998244353;

int n, m, x[M + 1], y[M + 1];
int id[M + 1];
int sum[N * 4 + 1], tag[N * 4 + 1];

inline void set_tag(int id, int tg) {
    sum[id] = sum[id] * 1LL * tg % P;
    tag[id] = tag[id] * 1LL * tg % P;
}

inline void push_down(int id) {
    if (tag[id] == 1)
        return;
    set_tag(id * 2, tag[id]);
    set_tag(id * 2 + 1, tag[id]);
    tag[id] = 1;
}

inline void update(int id) {
    sum[id] = (sum[id * 2] + sum[id * 2 + 1]) % P;
}

inline void build(int id, int l, int r) {
    tag[id] = 1;
    if (l == r)
        sum[id] = l == 1 ? 1 : 0;
    else {
        int mid = l + r >> 1;
        build(id * 2, l, mid);
        build(id * 2 + 1, mid + 1, r);
        update(id);
    }
}

inline void change(int id, int l, int r, int q, int to) {
    if (l == r)
        sum[id] = to;
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (q <= mid)
            change(id * 2, l, mid, q, to);
        else
            change(id * 2 + 1, mid + 1, r, q, to);
        update(id);
    }
}

inline void modify(int id, int l, int r, int ql, int qr, int tg) {
    if (ql > qr)
        return;
    else if (l == ql && r == qr)
        set_tag(id, tg);
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (qr <= mid)
            modify(id * 2, l, mid, ql, qr, tg);
        else if (ql > mid)
            modify(id * 2 + 1, mid + 1, r, ql, qr, tg);
        else {
            modify(id * 2, l, mid, ql, mid, tg);
            modify(id * 2 + 1, mid + 1, r, mid + 1, qr, tg);
        }
        update(id);
    }
}

int query(int id, int l, int r, int ql, int qr) {
    if (ql > qr)
        return 0;
    else if (l == ql && r == qr)
        return sum[id];
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (qr <= mid)
            return query(id * 2, l, mid, ql, qr);
        else if (ql > mid)
            return query(id * 2 + 1, mid + 1, r, ql, qr);
        else
            return (query(id * 2, l, mid, ql, mid) + query(id * 2 + 1, mid + 1, r, mid + 1, qr)) % P;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= m; i++)
        cin >> x[i] >> y[i];
    iota(id + 1, id + 1 + m, 1);
    sort(id + 1, id + 1 + m, [](int u, int v) {
        return x[u] < x[v];
    });
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        int l = x[id[i]], r = y[id[i]];
        modify(1, 1, n, 1, l - 1, 2);
        modify(1, 1, n, r + 1, n, 2);
        change(1, 1, n, r, (query(1, 1, n, r, r) + query(1, 1, n, l, r)) % P);
        // cerr << i << ": " << l << ' ' << r << '\n';
        // for (int j = 1; j <= n; j++)
        //  cerr << query(1, 1, n, j, j) << " \n"[j == n];
    }
    cout << query(1, 1, n, n, n) << '\n';
    return 0;
}

:::

方法二:容斥 DP

先转化成线段问题。对于每个线段,我们先将 r_i 减小 1,则原问题就变成了在 m 个线段里,有几种从中选择若干个的方式,使得选择的这些线段能够覆盖 1,2,\dots,n-1 所有点。

那正难则反,记 \operatorname{f}(S) 表示能够覆盖 S 内所有点的方案数,\operatorname{g}(S) 表示一定不覆盖 S 内所有点的方案数,则根据容斥原理,我们有 \operatorname{f}(\{1,2,3,\dots,n\})=\sum_{S\subseteq‌\{1,2,3,\dots,n\}}\operatorname{g}(S)\cdot(-1)^{|S|}

考虑如何计算 \operatorname{g}(S)。显然是 2^{cnt},其中 cnt 表示不包含 S 中点的线段的数量,于是有了如下暴力代码。

:::error[暴力 code]

#include <bits/stdc++.h>

using namespace std;

constexpr int N = 2e5;
constexpr int M = 2e5;
constexpr int P = 998244353;

int n, m, x[M + 1], y[M + 1];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        cin >> x[i] >> y[i];
        y[i]--;
    }
    int ans = 0;
    for (int i = 0; i < (1 << n - 1); i++) {
        int cnt = 0;
        for (int j = 1; j <= m; j++) {
            bool flag = true;
            for (int k = 1; k < n; k++)
                if (i >> k - 1 & 1)
                    flag &= !(x[j] <= k && k <= y[j]);
            cnt += flag;
        }
        int ret = 1;
        for (int j = 1; j <= cnt; j++)
            ret = (ret + ret) % P;
        if (__builtin_popcount(i) & 1)
            ans = (ans + P - ret) % P;
        else
            ans = (ans + ret) % P;
    }
    cout << ans << '\n';
    return 0;
}

:::

之后考虑用容斥 DP 优化。我们走到哪里算到哪里,考虑 f_i 表示只考虑 [1,i] 内部的点和线段,i 被选中不能覆盖的所有情况的容斥值之和,则 f_i=-\sum_{j=0}^{i-1}f_j\cdot2^{cnt},其中 cnt 表示完全在 [j+1,i-1] 内的线段个数。于是有了正常 DP。

:::warning[DP code]

#include <bits/stdc++.h>

using namespace std;

constexpr int N = 2e5;
constexpr int M = 2e5;
constexpr int P = 998244353;

int n, m, x[M + 1], y[M + 1];
int dp[N + 1 + 1];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        cin >> x[i] >> y[i];
        y[i]--;
    }
    dp[0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < i; j++) {
            int cnt = 1;
            for (int k = 1; k <= m; k++)
                if (j < x[k] && y[k] < i)
                    cnt = (cnt + cnt) % P;
            dp[i] = (dp[i] + dp[j] * 1LL * cnt % P * (P - 1) % P) % P;
        }
    }
    cout << (P - dp[n]) % P << '\n';
    return 0;
}

:::

然后是最后的优化。显然随着 i 的变大,cnt 需要考虑的线段会增多,于是在算完 f_i 后就把所有右端点为 i 的线段考虑进去。

我们用一棵线段树,记录在每个 i 时考虑了 [1,i-1] 内的所有线段,f_jf_i 的贡献即 f_j\cdot2^{cnt} 是多少,并支持区间求和,显然在加入一个线段 [l,i] 时就会对 [0,l-1] 做区间乘 2,算完 f_i 相当于对 i 点要做单点赋值,用类似于上一个方法的一颗线段树就能搞定,同样 \operatorname{O}(n\log n)

:::success[code]

#include <bits/stdc++.h>

using namespace std;

constexpr int N = 2e5;
constexpr int M = 2e5;
constexpr int P = 998244353;

int n, m, x[M + 1], y[M + 1];
vector<int> vec[N + 1];
int sum[N * 4 + 1], tag[N * 4 + 1], dp[N + 1];

inline void set_tag(int id, int tg) {
    sum[id] = sum[id] * 1LL * tg % P;
    tag[id] = tag[id] * 1LL * tg % P;
}

inline void push_down(int id) {
    if (tag[id] == 1)
        return;
    set_tag(id * 2, tag[id]);
    set_tag(id * 2 + 1, tag[id]);
    tag[id] = 1;
}

inline void update(int id) {
    sum[id] = (sum[id * 2] + sum[id * 2 + 1]) % P;
}

inline void build(int id, int l, int r) {
    tag[id] = 1;
    if (l == r)
        sum[id] = l == 1 ? 1 : 0;
    else {
        int mid = l + r >> 1;
        build(id * 2, l, mid);
        build(id * 2 + 1, mid + 1, r);
        update(id);
    }
}

inline void change(int id, int l, int r, int q, int to) {
    if (l == r)
        sum[id] = to;
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (q <= mid)
            change(id * 2, l, mid, q, to);
        else
            change(id * 2 + 1, mid + 1, r, q, to);
        update(id);
    }
}

inline void modify(int id, int l, int r, int ql, int qr, int tg) {
    if (ql > qr)
        return;
    else if (l == ql && r == qr)
        set_tag(id, tg);
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (qr <= mid)
            modify(id * 2, l, mid, ql, qr, tg);
        else if (ql > mid)
            modify(id * 2 + 1, mid + 1, r, ql, qr, tg);
        else {
            modify(id * 2, l, mid, ql, mid, tg);
            modify(id * 2 + 1, mid + 1, r, mid + 1, qr, tg);
        }
        update(id);
    }
}

int query(int id, int l, int r, int ql, int qr) {
    if (ql > qr)
        return 0;
    else if (l == ql && r == qr)
        return sum[id];
    else {
        int mid = l + r >> 1;
        push_down(id);
        if (qr <= mid)
            return query(id * 2, l, mid, ql, qr);
        else if (ql > mid)
            return query(id * 2 + 1, mid + 1, r, ql, qr);
        else
            return (query(id * 2, l, mid, ql, mid) + query(id * 2 + 1, mid + 1, r, mid + 1, qr)) % P;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        cin >> x[i] >> y[i];
        vec[--y[i]].push_back(x[i]);
    }
    build(1, 1, n);
    for (int i = 1; i <= n; i++) {
        dp[i] = (P - query(1, 1, n, 1, i)) % P;
        if (i < n) {
            change(1, 1, n, i + 1, dp[i]);
            for (int j : vec[i])
                modify(1, 1, n, 1, j, 2);
        }
    }
    cout << (P - dp[n]) % P << '\n';
    return 0;
}

:::

总结

本题中方法一无疑是最好的方法,但是大多数题,如果 DP 的状态就是 \operatorname{O}(n^2) 的,实际上能够优化到 \operatorname{O}(n) 左右的可能性不大。此题实际上是借用了线段树整体转移状态,而不将每个状态一一算出(十分类似于 CSP-S2024C)。如果想要在都算完之后迅速求出过程中某个状态的值,可以用可持久化线段树来实现。

对于方法二,实际上在此题中有些大材小用。线段覆盖相关问题正着做其实是很困难的,因为在每个时刻需要记录大于 \operatorname{O}(n) 种信息,不好直接计数,但如果通过容斥,正难则反,直接转化成十分简约的统计问题,虽然有 \operatorname{O}(2^n) 规模的子问题,但是可以在 DP 的过程中进行合并,并且每个子问题的处理模式非常单一,也是可做的。可以说容斥 DP 有很广泛的应用。

(于 202666 日修改了错误的容斥系数)