题解:AT_arc197_d [ARC197D] Ancestor Relation

· · 题解

通过精细化的实现,其实可以做到 O(n^2) 的复杂度。

首先对于有解的情况,不难发现以下两个性质:

  1. 将树根据“分叉点”拆成若干条链,则矩阵中两行相同当且仅当结点在同一条链上;

  2. 一条链是另一条链的祖先当且仅当矩阵中对应元素为 1 且前者所在行的 1 的个数大于后者。

根据这两个性质,可以通过矩阵重建出树的结构,步骤如下:

首先使用字典树或基数排序将矩阵各行按照字典序进行排序,用时 O(n^2)

然后按照字典序枚举每个结点,如果当前结点对应的矩阵中的行与上一个结点完全相同,则属于同一条链,否则新开一条链。通过这一步得到了所有链的划分,每条链用时 O(n),总用时 O(n^2)

最后,对每条链,找出其父链,父链根据以下两个特征唯一确定,每条链用时 O(n),总用时 O(n^2)

  1. 矩阵中对应元素为 1

  2. 所在行的 1 的个数比当前链大且尽可能小。

为了判断无解的情况,只需按照上面的步骤在链之间连边,如果连成的不是一颗合法的树(例如有环或不连通或结点 1 所在的链不是根)则无解,如果是合法的树就重算一遍矩阵,如果与输入矩阵不一致也无解,显然这样的判断是充分的,且可以在 O(n^2) 时间内完成。

考虑如何统计答案,对于结点 1 所在的链,除了结点 1 之外,其余结点的顺序可以任意排列,对于其他链,所有结点的顺序都可以任意排列,因此答案呈现为若干个阶乘的乘积,通过预处理阶乘可以用 O(n) 时间完成。

核心代码如下:

int main() {
    int t = read();
    while (t--) {
        int n = read();
        vector a(n, vector<bool>(n));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                a[i][j] = read() == 1;
        vector<int> sorted(n);
        iota(sorted.begin(), sorted.end(), 0);
        for (int v = n - 1; v >= 0; v--) {
            vector<int> zero, one;
            for (auto u : sorted)
                (a[u][v] ? one : zero).push_back(u);
            int cnt0 = zero.size(), cnt1 = n - cnt0;
            for (int i = 0; i < cnt1; i++)
                sorted[i] = one[i];
            for (int i = 0; i < cnt0; i++)
                sorted[i + cnt1] = zero[i];
        }
        vector<vector<int>> chain;
        vector<int> in_chain(n, -1);
        for (int r = 0; r < n; r++) {
            int i = sorted[r];
            if (r == 0 || a[i] != a[sorted[r - 1]])
                chain.emplace_back();
            in_chain[i] = chain.size() - 1;
            chain.back().push_back(i);
        }
        if (in_chain[0] != 0) {
            cout << "0\n";
            continue;
        }
        int m = chain.size();
        vector<int> deg(n);
        for (int i = 0; i < n; i++)
            deg[i] = count(a[i].begin(), a[i].end(), true);
        vector b(m, vector<bool>(m));
        for (int i = 0; i < n; i++) {
            int p = -1, mindeg = n + 1;
            for (int j = 0; j < n; j++) {
                if (!a[i][j])
                    continue;
                if (deg[j] > deg[i] && deg[j] < mindeg) {
                    p = in_chain[j];
                    mindeg = deg[j];
                }
            }
            if (p != -1)
                b[in_chain[i]][p] = b[p][in_chain[i]] = true;
        }
        int ecnt = 0;
        for (auto& bb : b)
            ecnt += count(bb.begin(), bb.end(), true);
        ecnt /= 2;
        if (ecnt != m - 1) {
            cout << "0\n";
            continue;
        }
        vector vis(m, false);
        function<void(int)> dfs1 = [&](int u) {
            vis[u] = true;
            for (int v = 0; v < m; v++)
                if (b[u][v] && !vis[v])
                    dfs1(v);
            };
        dfs1(0);
        if (count(vis.begin(), vis.end(), true) != m) {
            cout << "0\n";
            continue;
        }
        vector child(m, vector<int>());
        function<void(int, int)> dfs2 = [&](int u, int p) {
            child[u].push_back(u);
            for (int v = 0; v < m; v++)
                if (b[u][v] && v != p) {
                    dfs2(v, u);
                    for (auto ch : child[v])
                        child[u].push_back(ch);
                }
            };
        dfs2(0, -1);
        vector c(n, vector<bool>(n));
        for (int i = 0; i < m; i++) {
            for (auto u : chain[i])
                for (auto j : child[i])
                    for (auto v : chain[j])
                        c[u][v] = c[v][u] = true;
        }
        bool fail = false;
        for (int i = 0; i < n; i++)
            if (a[i] != c[i]) {
                fail = true;
                break;
            }
        if (fail) {
            cout << "0\n";
            continue;
        }
        m998 res = fac[chain[0].size() - 1];
        for (int i = 1; i < m; i++)
            res = res * fac[chain[i].size()];
        cout << res << '\n';
    }
    return 0;
}