题解 CF1326F2 【Wise Men (Hard Version)】

· · 题解

个人感觉写的比其他题解要详细?

不过这题本质上跟FWT的and卷积没啥关系…不知道某些老哥是怎么理解的…就只是单纯的长得像而已…

首先考虑,对于一个串 s 而言,直接统计比较麻烦,麻烦在难以体现「不认识」这个限制上。所以考虑如何忽略这个限制。考虑忽略限制后,就变成了统计 ans(s) 表示至少含有集合 s 的排列数。

那么对于 ans(s) 存在一个性质:如果 ss' 中,连通分量状态相同,那么两个集合的 ans 是等价的。此处连通分量指的是连续一段互相认识的人,状态相同指的是 ss' 的这些段大小相同,排布可以不同。(比如 0100111\Longleftrightarrow0011101\Longleftrightarrow1000111) 。

证明很简单,因为对于生成 01 的串而言,其方案数只在于有多少排列可以凑出这些 1 的连续段而已。换言之就是由于是全排列,所以对称。

值得注意的是,对于一个长度为 k1 连续段,其包含 k+1 个互相认识的人,也就是该连通块大小为 k+1

考虑如何求 ans(s) 。注意到现在已经转化成了统计每一种对 n 的拆分方式,有多少种排列数。那么一个比较自然的想法就是求下式

其中 t_i 表示枚举的第 i 个点集,s_i 表示组成 s 这个 01 串的第 i 个连通块(链形态的点集) ,f_{i,j} 表示从 size(i)=j 的集合 i 里面选出一个大小为 j的方案数。不难知道这些限制的意义:划分必须恰好划分掉 n 个点,且点集之间不存在交集(否则需要合并)。

上式的意义在于,对于 s 的一个划分,每个连通块都需要从某个点集中选出,而点集之间是互不相交的,所以 size 必须恰好是链长。于是可以从这个角度入手来求排列数。同时需要注意,由于我们是硬生生划定了 p 个集合,并不关心集合之间是否有连边,也就代表了其中有些单点(也就是 s 中的 0 位置)可能是与其他连通块是一体的,这也就符合了我们对 ans(s) 的定义:至少含有集合 s 的排列数。

注意到 f 是比较容易求得的。考虑 g_{i,j} 表示集合 ij 结尾的链的方案数,那么就是 n^2 转移,保证每次都用最后一个点转移就可以使得形态是一条链。f 就是对 g 的一个累加。这一部分复杂度 O(n^22^n)

考虑求出 f 之后如何计算这个式子。一个比较直接的想法是暴力枚举子集的子集来转移,由于 \Pi\Sigma 有分配律,可知转移是不难的。复杂度 P(n)3^n,其中 P(n) 是本质不同的划分数。可以过掉 n=14\rm F1

对于 \rm F2 ,考虑这么一个问题:如何保证一组相加起来 size 等于全集的子集互不相交?很显然是如果他们的就是全集,那么彼此之间一定不会有交。于是可以知道用 FMT 来优化这个过程。具体的,根据分配律,可以对每个 s_i 分开计算其贡献,进行 p 次 FMT 之后,答案就是第 2^n-1 项系数。

那么就做完了。最后只需对于每个 01 串,求出他的划分即可。注意到一开始 ans 的定义,需要我们做一次子集差变换。大概类似于FMT处理and的逆过程。最终复杂度 O(2^n(P(n)+n+n^2))

#include <bits/stdc++.h>

using namespace std ;

#define p_p pop_back
#define p_b push_back
#define vint vector<int>

typedef long long ll ;

const int N = 20 ;
const int M = 1200000 ;

int n ;
int m ;
int tot ;
ll I[M] ;
ll t[M] ;
ll ans[M] ;
ll f[M][N] ;
ll g[N][M] ;
ll res[N * N] ;
int know[N][N] ;
map<vint, int> Id ;
vint part[N * N], now ;

void fmt_or(ll *f, int g){
    int i, j ;
    for (i = 0 ; i < n ; ++ i)
        for (j = 0 ; j <= m ; ++ j)
            if (j >> i & 1) f[j] = f[j] + (ll)g * f[j ^ (1 << i)] ;
}
void dfs_part(int st, int big){
    if (st == n + 1)
        return void(part[Id[now] = ++ tot] = now) ;
    if (n - st + 1 >= big){
        for (int i = big ; i <= n - st + 1 ; ++ i)
            now.p_b(i), dfs_part(st + i, i), now.p_p() ;
    }
}
void dp1(){
    for (int i = 0 ; i < n ; ++ i) f[1 << i][i + 1] = 1 ;
    for (int i = 1 ; i <= m ; ++ i){
        int len = 0 ;
        for (int j = 1 ; j <= n ; ++ j){
            if (1 << (j - 1) & ~i) continue ;
            for (int k = 1 ; k <= n ; ++ k){
                if (!know[j][k]) continue ;
                if (1 << (k - 1) & i) continue ;
                f[i | (1 << (k - 1))][k] += f[i][j] ;
            }
            ++ len ;
        }
        for (int j = 1 ; j <= n ; ++ j)
            if (1 << (j - 1) & i) g[len][i] += f[i][j] ;
    }
}
void dp2(){
    for (int i = 1 ; i <= tot ; ++ i){
        memcpy(t, I, sizeof(I)) ;
        int o = (int)part[i].size() ;
        for (int j = 0 ; j < o ; ++ j)
            for (int k = 0 ; k <= m ; ++ k)
                t[k] *= g[part[i][j]][k] ;
        fmt_or(t, -1) ; res[i] = t[m] ;
    }
    //for (int i = 1 ; i <= tot ; ++ i) cout << res[i] << " " ; puts("") ;
}
void revalue(){
    for (int i = 0 ; i <= m ; ++ i){
        now.clear() ;
        for (int l = 0, r ; l < n ; l = r + 1){
            r = l ;
            while ((1 << r) & i && r < n)
                ++ r ; now.p_b(r - l + 1) ;
        }
        sort(now.begin(), now.end()) ;
        ans[i] = res[Id[now]] ;
    }
}
int main(){
    cin >> n ;
    m = (1 << n) - 1 ;
    I[0] = 1 ; fmt_or(I, 1) ;
    for (int i = 1 ; i <= n ; ++ i)
        for (int j = 1 ; j <= n ; ++ j)
            scanf("%1d", &know[i][j]) ; dp1() ;
    for (int i = 1 ; i <= n ; ++ i) fmt_or(g[i], 1) ;
    /*
    for (int i = 1 ; i <= n ; ++ i)
        for (int j = 1 ; j <= m ; ++ j)
            cout << g[i][j] << " \n"[j == m] ;
    */
    dfs_part(1, 1) ; dp2() ; revalue() ;
    //for (int j = 1 ; j <= m ; ++ j)
    //    cout << I[j] << " " ; puts("") ;
    ((m += 1) >>= 1) -= 1 ;
    for (int i = 0 ; i < n ; ++ i)
        for (int j = m ; j >= 0 ; -- j)
            if (~j >> i & 1){
                ans[j] = ans[j] - ans[j | (1 << i)] ;
            }
    for (int i = 0 ; i <= m ; ++ i)
        cout << ans[i] << " \n"[i == m] ;
    return 0 ;
}