题解:AT_joisc2017_b 港湾設備 (Port Facility)

· · 题解

该题解写得很详细,图比较多,建议不怎么擅长数据结构的阅读。

I. 初步分析得到 O(N^2) 做法

下面将一个物品入栈出栈看作一个区间,具体地,如果一个物品 iA_i 时刻入栈,B_i 时刻出栈,那么这个物品在时间轴上形成一个区间 [A_i,B_i]

我们先考虑如果只有一个栈,那么其入栈出栈是合法的,当且仅当对于里面的任意两个区间,它们存在包含关系不相交

从两面证明条件的充要性:

那么我们可以暴力 O(N^2) 枚举两个物品的区间,如果它们相交不包含,则在两物品间连一条边。

现在题目要求将每个物品放入 A,B 两个集合中的一个,使得在同一个集合内的物品没有边相连。

这使我们想到二分图。

那么我们暴力建图之后,如果该图不是二分图,则一定无解,答案为 0

否则答案为二分图合法黑白染色数量,我们考虑二分图的每个联通块都有恰好 2 种染色方式,我们先对该二分图计算其联通块数量 cnt,则答案为 2^{cnt}

至此,我们获得了一个 O(N^2) 的做法。

II. 减少边数,线段树优化

然而我们发现,这张图建出来,边数可能会是 O(N^2) 级别的:

所以我们如果要优化上述算法,应该从去掉一些不必要的边入手,毕竟我们只需要知道原图是否为二分图,并计算其连通块数量。

考虑前面 O(N^2) 做法,两个物品连边的条件:

那么我们有一种想法:

现在这个做法已经有了区间查询的雏形了,我们考虑用线段树进行优化:

这样连边的正确性证明:

对于时间复杂度的分析:

综上所述,该算法时间复杂度为 O(N \log N)

code(代码使用种类并查集代替遍历,复杂度和上述分析略有不同)

#include <bits/stdc++.h>
#define pii pair<int,int>
using namespace std;
const int mo = 1e9 + 7;
int n, f[2000001];

inline int find(int x) {
    return (x == f[x] ? x : f[x] = find(f[x]));
}

inline void merge(int x, int y) {
    f[find(x)] = find(y + n);
    f[find(x + n)] = find(y);
}

inline bool cross(pii A, pii B) {
    return ((A.first > B.first && A.first < B.second) ^ (A.second > B.first && A.second < B.second));
}
pii a[1000001];

#define lc (x<<1)
#define rc (x<<1|1)
#define mid ((l+r)>>1)
vector<int> v[8000001];
int ed[8000001];

inline void upd(int x, int l, int r, int id) {
    v[x].push_back(id);

    if (l == r)
        return;

    if (a[id].second <= mid)
        upd(lc, l, mid, id);
    else
        upd(rc, mid + 1, r, id);
}

inline void query(int x, int l, int r, int id) {
    if (l > a[id].second || r < a[id].first)
        return;

    if (l >= a[id].first && r <= a[id].second) {
        if (!ed[x] && !v[x].empty())
            ed[x] = v[x].back();

        if (cross(a[id], a[ed[x]]))
            merge(id, ed[x]);

        while (!v[x].empty() && cross(a[id], a[v[x].back()]))
            merge(id, v[x].back()), v[x].pop_back();

        return;
    }

    query(lc, l, mid, id);
    query(rc, mid + 1, r, id);
}

int main() {
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n;

    for (int i = 1; i <= n + n; ++i)
        f[i] = i;

    for (int i = 1; i <= n; ++i)
        cin >> a[i].first >> a[i].second;

    sort(a + 1, a + n + 1);

    for (int i = n; i; --i)
        upd(1, 1, n + n, i);

    for (int i = 1; i <= n; ++i)
        query(1, 1, n + n, i);

    int cnt = 0;

    for (int i = 1; i <= n; ++i)
        if (find(i) == i)
            ++cnt;

    for (int i = 1; i <= n; ++i)
        if (find(i) == find(i + n)) {
            cout << 0 << '\n';
            return 0;
        }

    int ans = 1;

    while (cnt--)
        ans = ans * 2 % mo;

    cout << ans << '\n';

    return 0;
}