题解 CF1326F2 【Wise Men (Hard Version)】

· · 题解

宣传 blog(Github Pages 上不去可以用 Gitee 上的镜像)

题目传送门

题解

发现既有有边又有没边的限制非常难搞,于是我们考虑容斥,即计算 ans_A 为对于所有 A 中为 1 的位置 i,满足 p_ip_{i+1} 有边的排列 p 的数量(对于 0 的位置没有限制)。

计算出 ans_A 后,只要用类似于「逆高维前缀和」的方法,就可以在 O(2^nn) 求出真正的答案。

考虑如何计算 ans_A。发现 ans_A 的值只与 A1 连续段的长度形成的可重集合有关。如 A=\{0,1,1,1,0,0,1\} 时,所有 1 连续段长度形成的集合就是 \{1,1,2,4\}(原长加 1)。

可以发现,对于任意 A,对应的 1 连续段长度形成的可重集合中所有元素的和一定恰好为 n。这意味着不同的集合数量等于 n 的划分数。

对于一个划分 P=\{a_1,a_2,\cdots,a_k\},可以发现对应的方案数实质上是在原图中选出 k 条不相交的简单路径,满足第 i 条路径的点数为 a_i 的方案数。

那么我们先考虑选 1 条路径。可以用 DP 求出 f_{u,S} 表示当前在节点 u,经过的点集为 S 时的路径条数,进而求出 g_S 表示经过的点集为 S 的路径条数。

那么对于一个划分 P=\{a_1,a_2,\cdots,a_k\},答案即为

\sum_{S_1,S_2,\cdots,S_k} \left[|S_i|=a_i\right]\left[\forall i\ne j:S_i\cap S_j=\varnothing\right]\left[\bigcup\limits_{i=1}^{k}S_i=U\right]\prod\limits_{i=1}^{k} g_{S_i}

其中 U 表示全集。

注意到由于 \sum\limits_{i=1}^{k} a_i=n,后面两个条件只需要保留一个即可。第二个条件很难处理,我们考虑保留第三个。若不考虑前两个条件,只考虑第三个条件,只需要用 FWT 即可方便地求出答案。

存在第一个条件时,我们只要将 g 加上一维,g_{l,S} 表示长度为 l,经过点集为 S 的简单路径数量。对于 |S|\ne l 的位置就置为 0。那么答案式子变成了

\sum_{S_1,S_2,\cdots,S_k} \left[\bigcup\limits_{i=1}^{k}S_i=U\right]\prod\limits_{i=1}^{k} g_{a_i,S_i}

直接将所有 g_i 进行一次 FWT 后,枚举划分时直接点乘。由于我们实际只需要卷积结果中 2^n-1 位置的值,所以我们不需要每次 IFWT,直接根据实际意义用容斥就可以 O(2^n) 求出这个值。

总时间复杂度 O(2^n(P(n)+n^2)),其中 P(n) 表示 n 的划分数。

实现得不好可能会变成 O(2^n(S(n)+n^2))O(2^nn(P(n)+n)),其中 S(n) 表示 n 的所有划分的大小之和。

代码

int n;
std::vector<int> bitcnt;
std::vector<std::vector<long long>> f;
std::vector<std::vector<long long>> g;
std::map<std::vector<int>, std::vector<int>> part;
std::vector<long long> ans;
void dfs(int r, const std::vector<int> &l, const std::vector<long long> &s){
    if (!r){
        long long res = 0;
        int mask = (1 << n) - 1;
        for (register int i = 0; i < (1 << n); ++i)
            if (bitcnt[mask ^ i] & 1) res -= s[i]; else res += s[i];
        const std::vector<int> &p = part[l];
        for (int v : p) ans[v] += res;
        return;
    }
    int lst = l.empty() ? 1 : l.back();
    for (register int k = lst; k <= r; ++k){
        if (k < r && r - k < k) continue;
        std::vector<int> nl(l);
        std::vector<long long> ns(s);
        nl.push_back(k);
        for (register int i = 0; i < (1 << n); ++i) ns[i] *= g[k - 1][i];
        dfs(r - k, nl, ns);
    }
}
void solve(){
    read(n);
    std::vector<std::vector<int>> E(n);
    for (register int i = 0; i < n; ++i)
        for (register int j = 0; j < n; ++j){
            char ch;
            while (!isdigit(ch = getchar())) ;
            if (ch == '1') E[i].push_back(j);
        }
    bitcnt.resize(1 << n);
    bitcnt[0] = 0;
    for (register int i = 1; i < (1 << n); ++i) bitcnt[i] = bitcnt[i >> 1] + (i & 1);
    f.resize(n, std::vector<long long>(1 << n, 0));
    g.resize(n, std::vector<long long>(1 << n, 0));
    for (register int i = 0; i < n; ++i) f[i][1 << i] = 1;
    for (register int S = 1; S < (1 << n); ++S)
        for (register int i = 0; i < n; ++i)
            if (f[i][S]){
                g[bitcnt[S] - 1][S] += f[i][S];
                for (int v : E[i]) if (!(S >> v & 1)) f[v][S | (1 << v)] += f[i][S];
            }
    for (register int k = 0; k < n; ++k)
        for (register int i = 0; i < n; ++i)
            for (register int S = 0; S < (1 << n); ++S)
                if (S >> i & 1) g[k][S] += g[k][S ^ (1 << i)];
    for (register int S = 0; S < (1 << (n - 1)); ++S){
        std::vector<int> p;
        int x = 1;
        for (register int i = 0; i < n - 1; ++i)
            if (S >> i & 1) ++x; else p.push_back(x), x = 1;
        p.push_back(x);
        std::sort(p.begin(), p.end());
        part[p].push_back(S);
    }
    ans.resize(1 << (n - 1), 0);
    dfs(n, std::vector<int>(), std::vector<long long>(1 << n, 1));
    for (register int i = 0; i < n - 1; ++i)
        for (register int S = 0; S < (1 << (n - 1)); ++S)
            if (S >> i & 1) ans[S ^ (1 << i)] -= ans[S];
    for (long long v : ans) print(v, ' ');
    putchar('\n');
}