P6043 题解 / 组合数前缀和

· · 题解

原题链接:P6043 「ACOI2020」修学旅行。

把之前的笔记整理一下,写个题解。

我们先来证明题目中要用到的结论(n\ge 2m):

\sum_{k=0}^m\binom mk^2\binom{n+k}{2m}=\binom nm^2

有一个叫 Zeilberger 的神奇算法,可以证明这条式子。但是我不会,所以只好用生成函数硬推了。

我们先证一个引理(n\ge m):

\sum_{k=0}^m(-1)^{m-k}\binom mk\binom {n+k}n=\binom nm

这个引理的证明较简单。有两种方法,一种是纯生成函数推导,这里采用另一种方法:二项式反演。

g_k=\binom {n+k}n,f_k=\binom nk,相当于我们要证明:

\sum_{k=0}^m(-1)^{m-k}\binom mkg_k=f_m

根据二项式反演的结论,我们只需要证明:

\sum_{k=0}^m\binom mkf_k=g_m

代入 f_k 定义后我们发现左式是一个卷积的形式,于是:

\sum_{k=0}^m\binom {m}{m-k}\binom{n}{k}=[x^m]\Big((1+x)^m*(1+x)^n\Big)=[x^m](1+x)^{m+n}=g_m

这样引理就得证了。

接着有以下两个基础的式子:

\binom mk=[x^m]\dfrac{x^k}{(1-x)^{k+1}} \binom {n+k}{2m}=[z^{2m}]\dfrac{1}{(1-z)^{n-2m+k+1}}

它们是基于以下这个等式得出来的:

\sum_{k}\binom{m+k}nx^k=\frac{x^{n-m}}{(1-x)^{n+1}}

数学归纳法可证,此处略过。

然后把这两个式子代入到左式,然后消掉一个元简化式子:

\begin{aligned} \sum_{k=0}^m\binom mk^2\binom{n+k}{2m}&=\sum_{k}[x^my^mz^{2m}]\dfrac{x^k}{(1-x)^{k+1}}\dfrac{y^k}{(1-y)^{k+1}}\dfrac{1}{z^{n-2m+k+1}}\\ &=[x^my^mz^{2m}]\dfrac{1}{(1-x)(1-y)(1-z)^{n-2m+1}}\dfrac{1}{1-\frac{xy}{(1-x)(1-y)(1-z)}}\\ &=[x^mz^{2m}]\dfrac{1}{(1-z)^{n-2m}}[y^m]\dfrac{1}{(z-xz-1)y+(x-1)(z-1)}\\ &=[x^mz^{2m}]\dfrac 1{(1-z)^{n-2m}}\dfrac{(1+xz-z)^m}{(1-x)^{m+1}(1-z)^{m+1}} \end{aligned}

最后一个等号是用到了 [x^n](ax+b)^{-1}=(-a)^nb^{-n-1} 这条等式,同样可归纳证明。

下一步,我们要把这个式子再展开,并化简成引理的形式:

\begin{aligned} \sum_{k=0}^m\binom mk^2\binom{n+k}{2m}&=[x^mz^{2m}]\sum_{k=0}^m\binom mk\dfrac{z^k(x-1)^k}{(1-x)^{m+1}(1-z)^{n-m+1}}\\ &=\sum_{k=0}^m(-1)^k\binom mk[x^m]\dfrac{1}{(1-x)^{m-k+1}}[z^{2m-k}]\dfrac{1}{(1-z)^{n-m+1}}\\ &=\sum_{k=0}^m(-1)^k\binom mk\binom{2m-k}{m}\binom{n+m-k}{n-m}\\ &=\binom nm\sum_{k=0}^m(-1)^k\binom mk\binom{n+m-k}{n}\\ &=\binom nm\sum_{k=0}^m(-1)^{m-k}\binom mk\binom{n+k}{n}\\ &=\binom nm^2 \end{aligned}

这样就证完了!结合这条式子就可以将「快乐度」的表达式化简为组合数前缀和了。

顺带一提,这个思路可以证明一个更普遍的结论:

\sum_k\binom nk\binom mk\binom {\alpha+k}{n+m}=\binom \alpha n\binom \alpha m

不过,这个证明有个小缺点——整体思路不够清晰,怎么这样推着推着就求出来了?我也不知道该怎么解释,因为我本来就是乱推的,只是一不小心就给证出来了,这也许就是瞎猫碰见死老鼠罢

接下来看这道题的第二部分,即快速求组合数前缀和。(注:以下的做法参考了 OI WIKI 的内容)

这个求法的关键是这条式子:

\begin{bmatrix}\binom{n}{m+1}_ {_ {_ {}}}\\\sum_{i=0}^m\binom ni\end{bmatrix}=\dfrac{1}{(m+1)!}\Bigg(\prod_{i=0}^m\,\begin{bmatrix}n-i&0\\i+1&i+1\end{bmatrix}\Bigg)\begin{bmatrix}1\\0\end{bmatrix}

要注意这里矩阵乘法的顺序,把累乘展开后,在最左边的应该是 i=m 的矩阵,最右边的是 i=0 的矩阵。

然后就可以按一般的方法来维护多项式点值了,先设:

M_d(x)=\prod_{i=1}^d\begin{bmatrix}-x+n-i+1&0\\x+i&x+i\end{bmatrix}=\begin{bmatrix}f_d(x)&0\\g_d(x)&h_d(x)\end{bmatrix}

那么 d\rightarrow 2d 的方法就是:

f_{2d}(x)&=f_d(x+d)f_d(x)\\ g_{2d}(x)&=f_d(x)g_d(x+d)+g_d(x)h_d(x+d)\\ h_{2d}(x)&=h_d(x+d)h_d(x) \end{aligned}

从这篇博客中,我们得知一个偷懒的小技巧,设 v=2^{\lceil \log_2\sqrt m\rceil},对于 M_d(x),我们维护 M_d(0),M_d(v),M_d(2v)...M_d(dv) 这些点值。因为 v2 的幂,所以可以绕开 d\rightarrow d+1 的步骤,这样就不用写拉插辣!不过这样会让常数稍稍变大,常数大慎用。

至于最外面那个 \frac 1{(m+1)!},注意到 h_d(x)=\frac{(x+d)!}{x!},所以 (v^2)!=h_v(0)h_v(v)...h_v(v^2-v)v^2m 附近,可以借助 (v^2)! 求出 (m+1)!

总的时间复杂度是 O(T\sqrt m\log m),以下是核心代码:

Poly T1, T2, T3;
int F[Mx], G[Mx], H[Mx];

void Solve(int d, const int B, const int N){
    if(d == 1){
        F[0] = N, F[1] = N - B;
        G[0] = H[0] = 1;
        G[1] = H[1] = B + 1;
        return;
    }
    Solve(d >> 1, B, N), d >>= 1;

    for(int i = 0; i <= d; ++i) T1.F[i] = F[i], T2.F[i] = G[i], T3.F[i] = H[i];
    T1.DotShift(d, d + 1), T2.DotShift(d, d + 1), T3.DotShift(d, d + 1);
    for(int i = (d << 1); i > d; --i) F[i] = T1.F[i - d - 1], G[i] = T2.F[i - d - 1], H[i] = T3.F[i - d - 1];

    int K = d * Inv(B) % Mod; d <<= 1;
    for(int i = 0; i <= d; ++i) T1.F[i] = F[i], T2.F[i] = G[i], T3.F[i] = H[i];
    T1.DotShift(d, K), T2.DotShift(d, K), T3.DotShift(d, K);

    for(int i = 0; i <= d; ++i){
        G[i] = (1ll * T2.F[i] * F[i] + 1ll * T3.F[i] * G[i]) % Mod;
        F[i] = 1ll * T1.F[i] * F[i] % Mod;
        H[i] = 1ll * T3.F[i] * H[i] % Mod;
    }
}

int BinomSum(int n, int m){ // sum[0 <= k <= m] C(n, k)
    if(n == m) return FastPow(2, n);
    if(m > (n >> 1)) return (FastPow(2, n) - BinomSum(n, n - m - 1)) % Mod;
    int v = ceil(log(sqrt(m)) / log(2)), B = 1 << v;
    Solve(B, B, n);
    int K = m / B, M = K * B;

    /* Matrix Multiply */
    int m1 = F[0], m2 = G[0], m3 = H[0];
    for(int i = 1; i < K; ++i){
        m2 = (1ll * m1 * G[i] + 1ll * m2 * H[i]) % Mod;
        m1 = 1ll * m1 * F[i] % Mod;
        m3 = 1ll * m3 * H[i] % Mod;
    }
    if(!K) m1 = m3 = 1, m2 = 0;

    m3 = Inv(m3), m2 = 1ll * m2 * m3 % Mod;
    int Ans = m2;
    for(int i = M; i <= m; ++i){
        Ans = (Ans + 1ll * m1 * m3) % Mod;
        m3 = m3 * Inv(i + 1) % Mod;
        m1 = 1ll * m1 * (n - i) % Mod;
    }

    return Ans;
}

双倍经验:P5388 [Cnoi2019] 最终幻想。