题解:P13275 [NOI2025] 集合

· · 题解

场上约 1h 通过此题,退役之战的高光时刻。一直训练的计数水平确实在这题上表现出来了,只可惜这 day 2 没有给我乘胜追击的机会,翻盘失败。

考虑枚举 S,计算 f(P) = f(Q) = S 的方案数。直观的想法是进行容斥:枚举集合 U\supseteq S, V \supseteq S,钦定 f(P) 中所有在 U 中的位为 1f(Q) 中所有在 V 中的位为 1

这样一个 (U, V)S 的贡献系数为 (-1)^{\mathrm{pop(U) + pop(V) - 2pop(S)}}(U, V) 的方案数容易计算:考虑所有的 i,如果 i 同时是 U, V 的超集,那么方案数为 1 + 2a_i;如果是 U, V 其中一个的超集,方案数为 1 + a_i;否则方案数为 1。把所有 i 的方案数乘起来即可。

这样已经可以做到 \mathcal O(8 ^ n) 了。

优化的想法很自然:考虑交换枚举顺序,先枚举 U, V,这时候可能产生贡献的 S 必须是 U\ \mathrm{and}\ V 的子集。现在和 S 有关的只有容斥系数,这个数只和 \mathrm{pop}(U\ \mathrm{and}\ V) 有关。

(U, V) 的方案数可以通过研究性质 B 想到一个方法:记 f_S 表示 S 所有超集的 (1 + a_i) 的乘积,g_S 表示 S 所有超集的 (1 + 2a_i) 的乘积。显然权值为 \frac{f_Uf_V}{g_{U \ \mathrm{or}\ V}}。因此容易优化到 \mathcal O(4 ^ n)。进一步地,\mathrm{pop}(U\ \mathrm{and}\ V) = \mathrm{pop}(U) + \mathrm{pop}(V) - \mathrm{pop}(U \ \mathrm{or} \ V),因此容易用 or - FWT 优化至 \mathcal O(2 ^ nn)

最后需要考虑一下 a_i = 998244352 如何处理。考虑扩域,对每个值记录二元组 (k, v) 表示乘上了 k0,显然对于每个 S 只有最小的 k 有效。容易发现 FWT 可以进行,因为具有可加性。而 IFWT 需要减法,看似会出现问题,但是实际上容易分析出正确性,因为如果 (k_1, v_1)(k_2, v_2) 时,一定有 k_2 \ge k_1,且 k_2>k_1 时不产生贡献。因此正确性得证,问题在 \mathcal O(2 ^nn) 时间复杂度内完美解决。

复现的代码:

#include <bits/stdc++.h>
using ll = long long;
using ld = long double;
using ull = unsigned long long;
using namespace std;
template <class T>
using Ve = vector<T>;
#define ALL(v) (v).begin(), (v).end()
#define pii pair<ll, ll>
#define rep(i, a, b) for(ll i = (a); i <= (b); ++i)
#define per(i, a, b) for(ll i = (a); i >= (b); --i)
#define pb push_back
bool Mbe;
ll read() {
    ll x = 0, f = 1; char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x * f;
}
void write(ll x) {
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
const ll Mod = 998244353;
ll n, a[(1 << 20) + 9];
ll pw(ll x, ll p) {
    ll res = 1;
    while(p) {
        if(p & 1) res = res * x % Mod;
        x = x * x % Mod, p >>= 1;
    }
    return res;
}
ll Add(ll x, ll y) {
    return ((x += y) >= Mod) ? (x - Mod) : (x);
}
ll Sub(ll x, ll y) {
    return ((x -= y) < 0) ? (x + Mod) : (x);
}
pii operator + (const pii &a, const pii &b) {
    if(a.first ^ b.first) return min(a, b);
    return {a.first, Add(a.second, b.second)};
}
pii operator - (const pii &a, const pii &b) {
    if(a.first ^ b.first) return a;
    return {a.first, Sub(a.second, b.second)};
}
pii operator * (const pii &a, const pii &b) {
    return {a.first + b.first, a.second * b.second % Mod};
}
pii f[(1 << 20) + 5], g[(1 << 20) + 5];
ll pw2[25], ipw2[25];
void FWT(pii *f) {
    for(ll i = 1; i < (1 << n); i <<= 1) {
        for(ll j = 0; j < (1 << n); j += (i << 1)) {
            rep(k, 0, i - 1) f[j + k + i] = f[j + k + i] + f[j + k];
        }
    }
}
void IFWT(pii *f) {
    for(ll i = 1; i < (1 << n); i <<= 1) {
        for(ll j = 0; j < (1 << n); j += (i << 1)) {
            rep(k, 0, i - 1) f[j + k + i] = f[j + k + i] - f[j + k];
        }
    }
}
ll pc(ll n) {
    return __builtin_popcount(n);
}
void solve() {
    n = read();
    rep(i, 0, (1 << n) - 1) a[i] = read();
    rep(i, 0, (1 << n) - 1) {
        if(a[i] != Mod - 1) {
            f[i] = {0, (1 + a[i]) % Mod};
            g[i] = {0, (1 + a[i] + a[i]) % Mod * pw((1 + a[i]) * (1 + a[i]) % Mod, Mod - 2) % Mod};
        }
        else {
            f[i] = {1, 1};
            g[i] = {2, Mod - 1};
        }
    }
    for(ll i = 1; i < (1 << n); i <<= 1) {
        for(ll j = 0; j < (1 << n); j += (i << 1)) {
            rep(k, 0, i - 1) {
                f[j + k] = f[j + k] * f[j + k + i];
                g[j + k] = g[j + k] * g[j + k + i];
            }
        }
    }
    rep(i, 0, (1 << n) - 1) {
        if(pc(i) & 1) f[i].second = Mod - f[i].second;
        f[i].second = f[i].second * pw2[pc(i)] % Mod;
    }
    FWT(f);
    rep(i, 0, (1 << n) - 1) f[i] = f[i] * f[i];
    IFWT(f);
    ll ans = 0;
    rep(i, 0, (1 << n) - 1) {
        if(f[i].first ^ g[i].first) continue;
        ans = (ans + f[i].second * g[i].second % Mod * ipw2[pc(i)]) % Mod;
    }
    write(ans), putchar('\n');
}
bool Med;
int main() {
    cerr << fabs(&Med - &Mbe) / 1048576.0 << "MB\n";
    ll T = read(); T = read();
    pw2[0] = 1;
    rep(i, 1, 22) pw2[i] = pw2[i - 1] * 2 % Mod;
    rep(i, 0, 22) ipw2[i] = pw(pw2[i], Mod - 2);
    while(T--) solve();
    cerr << "\n" << clock() * 1.0 / CLOCKS_PER_SEC * 1000 << "ms\n";
    return 0;
}