题解 CF1747E 【List Generation】

· · 题解

Update on 2022.11.9:感谢 @ClHg2。

前置芝士:二项式反演

考虑不合法序列的那个条件 a_i + b_i = a_{i - 1} + b_{i - 1},即 a_i = a_{i - 1}, b_i = b_{i - 1}

于是我们可以得到合法序列长度 \leq n + m + 1——因为每次 a, b 中至少减少一个。

然后发现每存在一个位置不合法,我们就可以把不合法的位置删掉——反正重复了,删了也不影响剩下的部分。当我们把所有不合法位置删完后,你会发现它变成了一个合法序列。

f(x) 表示长度恰好x合法序列的方案数,g(x) 表示长度恰好x不一定合法序列的方案数。

运用二项式反演,由 g(x) = \displaystyle\sum_{i = 2}^x C_{x - 1}^{i - 1} f(i) 可知 f(x) = \displaystyle\sum_{i = 2}^x (-1)^{x - i} C_{x - 1}^{i - 1} g(i)。答案即为 \displaystyle\sum_{i = 2}^{n + m + 1} i f(i)

现在考虑 g(x) 怎么算。注意到其实我们可以任取一个长度为 x 的满足不降限制的 a, b 序列,于是 g(x) = C_{n + x - 2}^n C_{m + x - 2}^m。代入可知:

原式 = \displaystyle\sum_{i = 2}^{n + m + 1} i \sum_{j = 2}^i (-1)^{i - j} C_{i - 1}^{j - 1} C_{n + j - 2}^n C_{m + j - 2}^m

= \displaystyle\sum_{j = 2}^{n + m + 1} (-1)^j C_{n + j - 2}^n C_{m + j - 2}^m \sum_{i = j}^{n + m + 1} (-1)^i i C_{i - 1}^{j - 1} = \displaystyle\sum_{j = 2}^{n + m + 1} (-1)^j j C_{n + j - 2}^n C_{m + j - 2}^m \sum_{i = j}^{n + m + 1} (-1)^i C_i^j

现在唯一的问题是后半部分。我们考虑按照 n + m + 1 \to 2 的顺序计算每个 j 的贡献。

我们显然要根据 (-1)^i 的奇偶性分开讨论。这里设 even, odd 分别表示当前偶数 / 奇数的 i 的贡献(这里不考虑符号)。

j + 1j 时,我们对于 C_i^{j + 1} \to C_i^j 的过程,注意到组合数递推式 C_{i + 1}^{j + 1} = C_i^j + C_i^{j + 1},于是我们:

但是对于 C_{n + m + 1}^j,上面我们会少加上一个 C_{n + m + 2}^{j + 1}……这个特判一下加给 even 还是 odd 即可。

时间复杂度为 O(N + \sum(n + m))

代码(要贺的话记得用 C++20 提交):

#include <stdio.h>

const int N = 1.5e7, mod = 1e9 + 7;
int fac[N + 7], inv_fac[N + 7];

inline int quick_pow(int x, int p){
    int ans = 1;
    while (p){
        if (p & 1) ans = 1ll * ans * x % mod;
        x = 1ll * x * x % mod;
        p >>= 1;
    }
    return ans;
}

inline void init(){
    fac[0] = 1;
    for (register int i = 1; i <= N; i++){
        fac[i] = 1ll * fac[i - 1] * i % mod;
    }
    inv_fac[N] = quick_pow(fac[N], mod - 2);
    for (register int i = N - 1; i >= 0; i--){
        inv_fac[i] = 1ll * inv_fac[i + 1] * (i + 1) % mod;
    }
}

inline int sub1(int x, int y){
    return x - y < 0 ? x - y + mod : x - y;
}

inline int comb(int n, int m){
    if (n < 0 || m < 0 || n < m) return 0;
    return 1ll * fac[n] * inv_fac[m] % mod * inv_fac[n - m] % mod;
}

inline void add(int &x, int y){
    if ((x += y) >= mod) x -= mod;
}

inline void sub2(int &x, int y){
    if ((x -= y) < 0) x += mod;
}

int main(){
    int t;
    scanf("%d", &t);
    init();
    for (register int i = 1; i <= t; i++){
        int n, m, even = 0, odd = 0, t, ans = 0;
        scanf("%d %d", &n, &m);
        t = n + m + 1;
        for (register int j = t; j >= 2; j--){
            int x = sub1(odd, even), y = mod - x;
            if (t % 2 == 0){
                add(x, comb(t + 1, j + 1));
            } else {
                add(y, comb(t + 1, j + 1));
            }
            even = x;
            odd = y;
            if (j % 2 == 0){
                add(ans, 1ll * j * sub1(even, odd) % mod * comb(n + j - 2, n) % mod * comb(m + j - 2, m) % mod);
            } else {
                sub2(ans, 1ll * j * sub1(even, odd) % mod * comb(n + j - 2, n) % mod * comb(m + j - 2, m) % mod);
            }
        }
        printf("%d\n", ans);
    }
    return 0;
}