P11054 [IOI2024] 斯芬克斯的谜题

· · 题解

自己做只会 47.5,感觉这题正解没有消息篡改者难但还是没有想到。

与正解无关的部分分做法在此就不赘述了。

首先考虑 50\% 怎么做,考虑一个增量构造的方法,如果我们已经确定的 1 \sim u 的导出子图的颜色相等关系,那么考虑加入点 u + 1,可以考虑二分出 1 \sim u 中第一个和 u + 1 有连边并且颜色相等的点(这个的原理是如果没有点和 u + 1 有边相连且颜色相等那么强制 u + 1 形成一个 单色分支,这样以来返回值就会达到最大,每次只需要判断一下 单色分支 的数量是否符合我们的预期即可,可以区分两种情况,而剩下我们不想考虑的点直接染成 N 即可),如果没有二分出来,那么 1 \sim u + 1 已经构造完毕,否则我们就让 u + 1 和二分出的点 v 连边(形成一个连通块),然后继续二分。注意每次二分前要判断一下全局是否有点满足条件,如果没有,就不用二分浪费 \log N 次询问了!

设最终连出了 x 个连通块,那么精确的询问次数就是 (N - x) \log N + N 的,接下来我们考虑怎么确定每个连通块的颜色。

这里我们考虑建出一棵连通块缩点后的生成树,将其黑白染色,容易发现黑点之间和白点之间如果有连边那么不可能有颜色相等的情况,所以我们排除了颜色之间的干扰,接下来我们要确定白点之间所有颜色为 c 的点,可以考虑将黑点全部染成 c,这样每个白点一定和一个颜色为 c 的黑点相连(这就是为什么要用生成树的结构!),然后二分出还没有确定颜色的白点中编号最小的且颜色为 c 的白点,这个判定可以用 50 分的做法的思想来做。

注意 x = 1 要特判。

分析一下次数:因为对于黑白点都要做一遍,并且每种颜色也有一次总的判定所以次数是 x \log N + 2N 的,总次数是 N \log N + 3N \approx 2750,可以通过。

#include "sphinx.h"
#include <bits/stdc++.h>

using namespace std;

int perform_experiment(vector<int> E);

vector<int> find_colours(int N, vector<int> X, vector<int> Y) {
    vector<pair<int, int>> st;
    int M = X.size();
    for (int i = 0; i < M; ++ i) st.push_back(make_pair(X[i], Y[i]));

    vector<int> dsu(N);
    iota(dsu.begin(), dsu.end(), 0);
    auto findFa = [&](auto&& self, int u) -> int {
        if (dsu[u] == u) {
            return u;
        } else {
            dsu[u] = self(self, dsu[u]);
            return dsu[u];
        }
    };
    auto LinkEdge = [&](int u, int v) -> void {
        u = findFa(findFa, u);
        v = findFa(findFa, v);
        if (u != v) {
            dsu[u] = v;
        }
    };

    vector<int> dsu2(N);
    iota(dsu2.begin(), dsu2.end(), 0);
    auto findFa2 = [&](auto&& self, int u) -> int {
        if (dsu2[u] == u) {
            return u;
        } else {
            dsu2[u] = self(self, dsu2[u]);
            return dsu2[u];
        }
    };
    auto LinkEdge2 = [&](int u, int v) -> void {
        u = findFa2(findFa2, u);
        v = findFa2(findFa2, v);
        if (u != v) {
            dsu2[u] = v;
        }
    };

    auto query = [&](vector<int> ord) -> int {
        iota(dsu2.begin(), dsu2.end(), 0);
        for (auto [U, V] : st) {
            if (ord[U] == ord[V]) {
                LinkEdge2(U, V);
            }
        }
        int cnt = 0;
        for (int i = 0; i < N; ++ i) {
            if (findFa2(findFa2, i) == i) {
                cnt ++;
            }
        }
        return cnt;
    };

    vector<int> ans(N);
    iota(dsu.begin(), dsu.end(), 0);
    int tot = 0;
    for (int i = 0; i < N; ++ i) {
        vector<int> que(N);
        int upp = 0;
        while (upp < i) {
            vector<int> vc;
            for (int j = upp; j < i; ++ j) {
                if (findFa(findFa, j) != findFa(findFa, i)) {
                    vc.push_back(j);
                }
            }
            if (vc.empty()) break;
            int L = 0, R = int(vc.size()) - 1, pos = i;
            tot ++;
            vector<int> ans2(N);
            for (int j = 0; j < N; ++ j) {
                if (findFa(findFa, j) == findFa(findFa, i) || (upp <= j && j <= vc[R])) {
                    que[j] = -1;
                    ans2[j] = (findFa(findFa, j) == findFa(findFa, i) ? N : findFa(findFa, j));
                } else {
                    que[j] = N;
                    ans2[j] = -1;
                }
            }
            if (perform_experiment(que) == query(ans2)) {
                break;
            }
            while (L <= R) {
                int mid = (L + R) / 2;
                for (int j = 0; j < N; ++ j) {
                    if (findFa(findFa, j) == findFa(findFa, i) || (upp <= j && j <= vc[mid])) {
                        que[j] = -1;
                        ans2[j] = (findFa(findFa, j) == findFa(findFa, i) ? N : findFa(findFa, j));
                    } else {
                        que[j] = N;
                        ans2[j] = -1;
                    }
                }
                if (perform_experiment(que) != query(ans2)) {
                    pos = vc[mid];
                    R = mid - 1;
                } else {
                    L = mid + 1;
                }
            }
            if (pos == i) break;
            LinkEdge(i, pos);
            upp = pos + 1;
        }
    }

    bool only = true;
    for (int i = 1; i < N; ++ i) {
        only &= (findFa(findFa, 0) == findFa(findFa, i));
    }
    if (only) {
        int ansC = -1;
        for (int C = 0; C < N; ++ C) {
            vector<int> que(N);
            que[0] = -1;
            for (int x = 1; x < N; ++ x) {
                que[x] = C;
            }
            if (perform_experiment(que) == 1) {
                ansC = C;
                break;
            }
        }
        for (auto& x : ans) x = ansC;

        return ans;
    }

    vector<int> mp(N);
    vector<int> bipartite(N);
    for (auto& x : mp) x = -1;
    for (auto& x : bipartite) x = -1;
    vector<vector<int>> vertex(N);
    for (auto& x : vertex) x.clear();
    for (int i = 0; i < N; ++ i) vertex[findFa(findFa, i)].push_back(i);
    iota(dsu2.begin(), dsu2.end(), 0);
    for (auto [U, V] : st) {
        int X_ = findFa(findFa, U);
        int Y_ = findFa(findFa, V);
        if (findFa2(findFa2, X_) != findFa2(findFa2, Y_)) {
            if (bipartite[X_] == -1 && bipartite[Y_] == -1) {
                bipartite[X_] = 0;
                bipartite[Y_] = 1;
            } else if (bipartite[X_] == -1) {
                bipartite[X_] = 1 - bipartite[Y_];
            } else if (bipartite[Y_] == -1) {
                bipartite[Y_] = 1 - bipartite[X_];
            } else if (bipartite[X_] == bipartite[Y_]) {
                for (int i = 0; i < N; ++ i) {
                    if (findFa2(findFa2, i) == findFa2(findFa2, X_)) {
                        bipartite[i] ^= 1;
                    }
                }
            }
            LinkEdge2(X_, Y_);
        }
    }

    for (int curX = 0; curX < 2; ++ curX) {
        for (int C = 0; C < N; ++ C) {
            vector<int> que(N);
            vector<int> ans2(N);
            for (auto& x : ans2) x = -1;
            for (auto& x : que) x = -1;
            vector<int> vc, vc2;
            for (int i = 0; i < N; ++ i) {
                if (bipartite[i] == 1 - curX) {
                    for (auto u : vertex[i]) {
                        que[u] = C;
                    }
                } else if (bipartite[i] == curX) {
                    if (!vertex[i].empty()) {
                        if (mp[vertex[i][0]] == -1) {
                            vc.push_back(i);
                        } else {
                            vc2.push_back(i);
                        }
                    }
                }
            }
            if (!vc.empty()) {
                sort(vc.begin(), vc.end());

                int upp = 0;
                while (upp < int(vc.size())) {
                    auto check = [&](int x) -> bool {
                        for (int i = 0; i < int(vc.size()); ++ i) {
                            if (upp <= i && i <= x) {
                                for (auto u : vertex[vc[i]]) {
                                    que[u] = -1;
                                    ans2[u] = findFa(findFa, u);
                                }
                            } else {
                                for (auto u : vertex[vc[i]]) {
                                    que[u] = N;
                                    ans2[u] = N;
                                }
                            }
                        }
                        for (auto x : vc2) {
                            for (auto u : vertex[x]) {
                                que[u] = N;
                                ans2[u] = N;
                            }
                        }
                        if (perform_experiment(que) != query(ans2)) {
                            return true;
                        } else {
                            return false;
                        }
                    };

                    int L = upp, R = int(vc.size()) - 1, pos = int(vc.size());
                    if (!check(R)) break;
                    while (L <= R) {
                        int mid = (L + R) / 2;
                        if (check(mid)) {
                            pos = mid;
                            R = mid - 1;
                        } else {
                            L = mid + 1;
                        }
                    }
                    for (auto u : vertex[vc[pos]]) {
                        mp[findFa(findFa, u)] = C;
                    }
                    upp = pos + 1;
                }
            }
        }
    }

    for (int i = 0; i < N; ++ i) {
        ans[i] = mp[findFa(findFa, i)];
    }

    return ans;
}