[数学] 容斥权值也可以「待定系数」?

· · 算法·理论

前言:笔者的数学相较于其他板块非常辣鸡,所以说如果对于同一道题你会使用正经方法推容斥系数,就别用这个下策了。但是用下策过题感觉是真好。

很多萌新面对有些题目可以想到容斥,但是死活推不出系数,然后盲代 (-1)^k 之后对的说不出所以然,错的就直接不知道怎么改了。面对这类问题,大家都知道数学不好(如笔者)是原罪,但是确实不会推导还想过容斥题怎么办呢?今天笔者 T 氏来给大家介绍一个替代品,即待定系数法打表出容斥系数。

众所周知的是数学课上教的待定系数法,即先钦定一个式子满足某某形式,但我们并不知道该形式的参数,再用解方程把这个参数求出来。这类方法的好处是无脑简单,只要形式是对的那么必定能够算出来;坏处也很明显,即这个形式需要猜,猜错了可能无解。

我们发现这个过程同样可以套用在容斥的过程中:某些题目我们看一看就知道要容斥,但是推不出系数。此时不妨假想一个 DP 框架,再根据这个框架的需求反向推导出容斥系数

具体的意思是,合理的容斥系数只需要满足对于我们想要计数的内容 x,与 x 相关的函数 f(x) = 1,且对于不合法的内容 xf(x) = 0 即可。我们已经知道 f 函数(就是 DP 的过程),那么可以反向 DP 或解方程,算出 f 中不同项的系数,即容斥系数。

我们结合例题来理解这个思想。

待定系数法的基本应用

P6846 [CEOI 2019] Amusement Park

有一个含 n 个点,m 条边的有向图,图无重边,无自环,两点之间不成环。

现在我们想改变一些边的方向,使得该有向图无环。

您需要求出,每一种改变方向后使得该有向图无环的方案的需改变边的数量之和 \bmod\ 998244353 之后的答案。

首先边的数量之和就是骗人的,算出总方案数之后乘 \frac{m}{2} 就好了。

不难想到设 dp_{S} 表示集合 S 的答案。每次转移时枚举 T 作为无环图中入度为 0 的节点集合。

很快我们意识到这样会重复计算,因为 S - T 中可能还有入度为 0 的点。

因此现在我们发现需要容斥。考虑待定系数法,钦定当前 DP 框架下一定存在一个系数 a,使得我们的转移式子中的 dp_T 改为 a_Tdp_T 后不重不漏。

那么这个系数是多少呢?这时候就可以解方程计算了。首先我们需要列出条件:

这个剩下总贡献就是 f(S - T) 啊,我们已经知道 f(S - T) 的确切值了,因此容斥系数可以直接用如下代码打表:

::::info[打表程序]

/* 省选:2026.3.7 */
/* Hatsune Miku x Kasane Teto */
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
using namespace std;

const int N = 19, mod = 998244353;
int n = 9, a[1 << 9];
inline int f(int x) {return (x != 0);}

int main()
{
//  freopen("text.in", "r", stdin);
//  freopen("prog.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    for(int i = 1; i < (1 << n); ++i)
    {
        int sum = 0;
        for(int j = i; j; j = ((j - 1) & i))
            sum += f(i - j) * a[j];
        a[i] = f(i) - sum;
    }
    return 0;
}

::::

根据待定系数法的性质,只要 a_i 能算出来那一定是对的。因此直接用这个 a 数组就能写个 O(3^n) 的做法。由于 T 氏的数学太没救了所以不会用哈集幂优化。

题外话:打表出来的 a 就是 (-1)^{|S| + 1},这是一个经典结论。

::::success[AC 代码(复杂度 O(3^n))]

/* 省选:2026.3.7 */
/* Hatsune Miku x Kasane Teto */
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
using namespace std;

const int N = 19, mod = 998244353;
int n, m, a[1 << 18]; // a 表示打表出的容斥系数 
inline int f(int x) {return (x != 0);}

struct Link {int x, y;} e[N * N];
long long dp[1 << 18]; bool g[1 << 18];
// dp[S] 表示 S 集合的答案,g[S] 存 S 是不是一个独立集 

int main()
{
//  freopen("text.in", "r", stdin);
//  freopen("prog.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= m; ++i) cin >> e[i].x >> e[i].y;

    // 打表容斥 
    for(int i = 1; i < (1 << n); ++i)
    {
        int sum = 0;
        for(int j = i; j; j = ((j - 1) & i))
            sum += f(i - j) * a[j];
        a[i] = f(i) - sum;
    }

    // 判定独立 
    for(int mask = 1; mask < (1 << n); ++mask)
    {
        g[mask] = true;
        for(int i = 1; i <= m; ++i)
            if(((1 << (e[i].x - 1)) & mask) && ((1 << (e[i].y - 1)) & mask))
                g[mask] = false;
    }

    // DP 求解 
    dp[0] = 1;
    for(int mask = 1; mask < (1 << n); ++mask)
    {
        for(int sub = mask; sub; sub = ((sub - 1) & mask))
            if(g[sub])
                dp[mask] = (dp[mask] + dp[mask - sub] * (a[sub] + mod)) % mod;
    }
    cout << dp[(1 << n) - 1] * m % mod * (mod + 1) / 2 % mod << '\n';
    return 0;
}

::::

P10591 BZOJ4671 异或图

同样是待定系数法的基本应用。我们发现若钦定若干个块相互不连通,则它们内部可能也不连通,从而导致重复计算。

根据转移式子写出打表程序如下:

::::info[计算系数]

long long C[N][N], w[N];
inline long long F(int x) {return (x == 1);}
inline void init()
{
    // 待定系数法求解容斥系数 
    for(int i = 0; i <= n; ++i)
    {
        C[i][0] = 1;
        for(int j = 1; j <= i; ++j)
            C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
    }
    for(int i = 1; i <= n; ++i)
    {
        long long sum = 0;
        for(int j = 1; j < i; ++j)
            sum += C[i - 1][j - 1] * w[j] * F(i - j);
        w[i] = F(i) - sum;
    }
}

::::

P10104 [GDKOI2023 提高组] 异或图

跳过 m = 0 的部分(和容斥无关),如果不会可以看看其他题解。

我们现在想做的就是每次钦定一堆点,并声称这堆点的权值相同。但是也会有计算重复,还是因为不能保证其他点有没有权值和钦定点一样的。

然后我们自然而然地想到设容斥系数。此题中显然有 f(S) = [S \ 是独立集],那么按照上文的方法做就能推出容斥系数了。

顺便一提,这道题的系数貌似只能这样得到,因为图是数据给你的,那么 f(S) 就并不是一个只和 |S| 相关的式子了。这也体现了待定系数容斥的好处。

寻找规律

P7275 计树

容易想到每次钦定一条值连续的链,再用 \text{Prüfer} 的结论 ans = n^{k - 2}\prod siz_i 计算。这样一条长度为 x 的链的权值是 [x \geq 2]xn,最后除以 n^2 就是答案。

但是我们发现每次钦定的链不是极长的,因此计算重复,又需要设置容斥系数了。

仍然考虑待定系数法:

剩下部分的贡献还是 f,那不就直接做出来了?打表可得 x : 1 \to n 的容斥系数 w_x

::::info[能过 O(n^2) 部分分的代码]

/* 省选:2026.3.7 */
/* Hatsune Miku x Kasane Teto */
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
using namespace std;

const int N = 1010, mod = 998244353;
inline long long qpow(long long a, long long b)
{
    long long res = 1;
    while(b)
    {
        if(b & 1) res = res * a % mod;
        b >>= 1, a = a * a % mod;
    }
    return res;
}
inline void Plus(long long &now, long long add)
{now += add; while(now >= mod) now -= mod;}
int n;

long long w[N], g[N], dp[N];
inline int F(int x) {return (x >= 2);}

int main()
{
//  freopen("text.in", "r", stdin);
//  freopen("prog.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n;
    for(int i = 1; i <= n; ++i)
    {
        long long sum = 0;
        for(int j = 1; j < i; ++j)
            Plus(sum, w[j] % mod * F(i - j) % mod);
        w[i] = (F(i) - sum + mod) % mod;
    }

    dp[0] = 1;
    for(int i = 2; i <= n; ++i)
    {
        for(int j = 2; j <= i; ++j)
            Plus(dp[i], (1ll * j * n % mod) * w[j] % mod * dp[i - j] % mod);
    }
    cout << dp[n] * qpow(1ll * n * n % mod, mod - 2) % mod << '\n';
    return 0;
}

::::

如果你觉得不够过瘾,那么可以把 w 输出一下,我们发现它居然是周期为 6 的周期数列,那直接 O(n) 得到 w 的每一项。再看这个 DP 是半在线卷积的形式,于是分治 NTT 就做完了,复杂度 O(n \log^2 n),肯定不敌生成函数,然而不影响我们 AC。

::::success[优化后的代码]

/* 省选:2026.3.7 */
/* Hatsune Miku x Kasane Teto */
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
using namespace std;

const int N = 100010, mod = 998244353;
inline long long qpow(long long a, long long b)
{
    long long res = 1;
    while(b)
    {
        if(b & 1) res = res * a % mod;
        b >>= 1, a = a * a % mod;
    }
    return res;
}
inline void Plus(long long &now, long long add)
{now += add; while(now >= mod) now -= mod;}
int n;

const int g = 3, gi = (mod + 1) / g;
int rev[N * 4];
struct Poly
{
    vector < long long > f;
    inline void NTT(int len, int type)
    {
        f.resize(len);
        for(int i = 0; i < len; ++i) rev[i] = rev[i / 2] / 2 + ((i & 1) ? (len / 2) : 0);
        for(int i = 0; i < len; ++i) if(i > rev[i]) swap(f[i], f[rev[i]]);
        for(int i = 1; (1 << i) <= len; ++i)
        {
            long long wn = qpow((type == 1) ? g : gi, (mod - 1) >> i);
            for(int j = 0; j < len; j += (1 << i))
            {
                long long w = 1;
                for(int k = j; k < j + (1 << (i - 1)); ++k, w = w * wn % mod)
                {
                    long long n1 = f[k], n2 = f[k + (1 << (i - 1))] * w % mod;
                    f[k] = (n1 + n2) % mod, f[k + (1 << (i - 1))] = (n1 - n2 + mod) % mod;
                }
            }
        }
        if(type == -1)
        {
            long long iv = qpow(len, mod - 2);
            for(int i = 0; i < len; ++i) f[i] = f[i] * iv % mod;
        }
    }

    inline void output() {for(long long x : f) cout << x << " "; cout << '\n';}
};
Poly operator * (Poly u, Poly v)
{
    int r_len = u.f.size() + v.f.size() - 1, len = 1;
    while(len < r_len) len <<= 1;

    Poly res; u.NTT(len, 1), v.NTT(len, 1);
    for(int i = 0; i < len; ++i) res.f.push_back(u.f[i] * v.f[i] % mod);
    res.NTT(len, -1);

    res.f.resize(r_len); res.f.shrink_to_fit(); return res;
}

long long dp[N], w[N];
inline void solve(int L, int R)
{
    // 半在线卷积 
    if(L == R) {if(!L) dp[0] = 1; return ;}
    int mid = (L + R) >> 1;
    solve(L, mid);

    Poly A, B, C;
    for(int i = L; i <= mid; ++i) A.f.push_back(dp[i]);
    for(int i = 0; i <= R - L; ++i) B.f.push_back(w[i]);
    C = A * B;
    for(int i = mid + 1; i <= R; ++i) Plus(dp[i], C.f[i - L]);

    solve(mid + 1, R);
}

int main()
{
//  freopen("text.in", "r", stdin);
//  freopen("prog.out", "w", stdout);
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n;

    for(int i = 0; i <= n; ++i)
    {
        if(i == 0) w[i] = 0;
        else switch(i % 6)
        {
            case 1 : w[i] = 0; break;
            case 2 : w[i] = 1; break;
            case 3 : w[i] = 1; break;
            case 4 : w[i] = 0; break;
            case 5 : w[i] = mod - 1; break;
            default : w[i] = mod - 1; break;
        }
    }
    for(int i = 2; i <= n; ++i) w[i] = w[i] * i % mod * n % mod;
    solve(0, n);

    cout << dp[n] * qpow(1ll * n * n % mod, mod - 2) % mod << '\n';
    return 0;
}

::::

#3395. 「2020-2021 集训队作业」Yet Another Permutation Problem

先转化题意,把拿出来放到两端映射为从两端放回中间,那么我们稍加推到就知道如何对一个排列 O(n) 求解最小次数:即 n 减去最长连续值域上升子序列长度。

比如这个排列 1, 7, 3, 4, 5, 6, 2 的最长连续值域上升子序列是 3, 4, 5, 6,那么这个排列的答案是 2

下面称这种子序列为钻头序列。

不难想到先枚举最长钻头序列的长度 t,再在内部 DP,这样复杂度肯定不高于 O(n^3)

那么内部 DP 时,我们又遇到经典问题:钻头序列不一定是极长的,计算重复。这个好说,直接上待定系数:

解方程即可。

然后我们发现 O(n^3) 超时了。我们把 nw 序列输出,发现它们不仅是周期序列,还只有 O(n\ln n) 个位置非零。这还有啥好说的,直接 O(n^2\ln n) 做完了,时间复杂度和标准解法一致。

以上就是全部内容了。可以看出待定系数体现了容斥的本质,即“通过巧妙分配权值,使得合法的方案恰好被计算一次”。通过给出条件让电脑帮我们推系数,我们的思考过程可以被充分简化。