载谭 Binomial Sum 小练习

题解

不妨令 F_0=1(并约定 F_{-1}=0),那么 F_i = F_{i-1}+F_{i-2} 就对所有 i\ge 1 成立。
然后我们考虑求出

F(z) = \frac1{1-z-z^2} \bmod z^{n+1} = \sum\limits_{i=0}^n F_i z^i

由于只保留到 z^n,递推式在 z^{n+1},z^{n+2} 处多出的贡献显然需要减去。即

\begin{aligned} F(z) &= zF(z) + z^2F(z) + 1 - (F_{n-1}+F_n)z^{n+1} - F_nz^{n+2} \\ F(z) &= \frac{1 - (F_{n-1}+F_n)z^{n+1} - F_nz^{n+2}}{1-z-z^2} \end{aligned}

则答案就是

\begin{aligned} \sum\limits_{i=0}^n F_i i^k &= \left[\frac{z^k}{k!}\right] \sum\limits_{i=0}^n F_i \mathrm e^{iz} \\ &= \left[\frac{z^k}{k!}\right] F(\mathrm e^z) \end{aligned}

令 \mathscr F(z + 1) = F(z+1) \bmod z^{k+1},
P(z) = 1 - (F_{n-1}+F_n)z^{n+1} - F_nz^{n+2},则

\begin{aligned} F(z+1) &= -\frac{P(z+1)}{1+3z+z^2} \\ \left([z^i] F(z+1)\right) &= -\left([z^i] P(z+1)\right) - 3\left([z^{i-1}] F(z+1)\right) - \left([z^{i-2}] F(z+1)\right) \end{aligned}

则令 \mathscr P(z+1) = P(z+1) \bmod z^{k+1},先求 \mathscr P(z),主要是要考虑

\mathscr G_m(z+1) = (1+z)^m \bmod z^{k+1}

考虑

\begin{aligned} (1+z)[(1+z)^m]' &= m(1+z)^m \\ (1+z)[\mathscr G_m(z+1)]' &= m\mathscr G_m(z+1) - (m-k) \binom mk z^k \\ z\mathscr G_m'(z) &= m\mathscr G_m(z) - (m-k) \binom mk (z-1)^k \end{aligned}

从而可以求出 \mathscr P(z)

\begin{aligned} &\quad\;\left([z^i] \mathscr F(z+1)\right) \\ &= -\left([z^i] \mathscr P(z+1)\right) - 3\left([z^{i-1}] \mathscr F(z+1)\right) - \left([z^{i-2}] \mathscr F(z+1)\right) \\ &+ [i=k+1]\left[3\left([z^k] \mathscr F(z+1)\right) + \left([z^{k-1}] \mathscr F(z+1)\right)\right] \\ &+ [i=k+2]\left([z^k] \mathscr F(z+1)\right) \end{aligned}

不妨令 C_1 = \left[3\left([z^k] \mathscr F(z+1)\right) + \left([z^{k-1}] \mathscr F(z+1)\right)\right],C_2 = \left([z^k] \mathscr F(z+1)\right),则

\mathscr F(z+1) = \frac{-\mathscr P(z+1)+C_1 z^{k+1}+C_2 z^{k+2}}{1+3z+z^2}

\mathscr F(z) = \frac{\mathscr P(z) - C_1(z-1)^{k+1} - C_2(z-1)^{k+2}}{1-z-z^2}

根据 EI 的结果,[z^k] F(\mathrm e^z) = [z^k] \mathscr F(\mathrm e^z)。线性求出 \mathscr F(z) 即可。
时间复杂度 \Theta(k + \log n)

代码:

#include <cstdio>
#include <algorithm>
using namespace std;
// Capacity constant: the global tables in this file are declared with K + 5 entries.
const int K = 42;
// Modulus used by every arithmetic helper below.
const int mod = 1e9 + 7;

// Returns x - mod when x >= mod, otherwise x unchanged
// (a single conditional subtraction, not a general remainder).
inline int norm(int x)
{
    if(x < mod)
        return x;
    return x - mod;
}

// Returns x + mod when x is negative, otherwise x unchanged.
inline int reduce(int x)
{
    if(x >= 0)
        return x;
    return x + mod;
}

// Additive inverse: 0 maps to 0, any other x maps to mod - x.
inline int neg(int x)
{
    if(x == 0)
        return 0;
    return mod - x;
}

// In-place x += y with one conditional correction by mod.
inline void add(int &x,int y)
{
    // Subtract mod up front so the intermediate never exceeds the int range,
    // then undo the subtraction if the result went negative.
    x += y - mod;
    if(x < 0)
        x += mod;
}

// In-place x -= y with one conditional correction by mod.
inline void sub(int &x,int y)
{
    x -= y;
    if(x < 0)
        x += mod;
}

// Fused multiply-add: x = (x + y * z) % mod, evaluated in 64-bit arithmetic.
inline void fam(int &x,int y,int z)
{
    const long long product = (long long)y * z;
    x = (x + product) % mod;
}

// Binary exponentiation: returns a^b % mod (1 when b == 0).
inline int fpow(int a,int b)
{
    int result = 1;
    while(b)
    {
        if(b & 1)
            result = (long long)result * a % mod;
        a = (long long)a * a % mod;
        b >>= 1;
    }
    return result;
}
// Upper summation index, read from stdin by main; main later overwrites it with n % mod.
long long n;
// Exponent, read from stdin by main.
int k;
// Set by main from two bostanMori calls (indices n - 1 and n);
// main then adds F_n into F_n_1, so F_n_1 ends up holding the sum of both terms.
int F_n_1;
int F_n;
// Returns the coefficient of z^n in (p0 + p1*z) / (q0 + q1*z + q2*z^2), modulo mod.
//
// Each round multiplies numerator and denominator by Q(-z) = q0 - q1*z + q2*z^2:
//   Q(z)Q(-z) = q0^2 + (2*q0*q2 - q1^2) z^2 + q2^2 z^4      (even powers only)
//   P(z)Q(-z) = p0*q0 + (p1*q0 - p0*q1) z + (p0*q2 - p1*q1) z^2 + p1*q2 z^3
// and keeps the numerator coefficients whose parity matches n, then halves n.
// Arguments are presumably residues in [0, mod) - every call in this file passes such values.
inline int bostanMori(long long n,int p0,int p1,int q0,int q1,int q2)
{
    while(n)
    {
        const int minusQ1 = mod - q1;
        int nextP0;
        int nextP1;
        if(n & 1)
        {
            // Odd n: keep the z^1 and z^3 coefficients of P(z)Q(-z).
            nextP0 = ((long long)q0 * p1 + (long long)minusQ1 * p0) % mod;
            nextP1 = (long long)q2 * p1 % mod;
        }
        else
        {
            // Even n: keep the z^0 and z^2 coefficients of P(z)Q(-z).
            nextP0 = (long long)q0 * p0 % mod;
            nextP1 = ((long long)q2 * p0 + (long long)minusQ1 * p1) % mod;
        }
        const int nextQ0 = (long long)q0 * q0 % mod;
        const int nextQ1 = (2LL * q0 * q2 + (long long)minusQ1 * q1) % mod;
        const int nextQ2 = (long long)q2 * q2 % mod;
        p0 = nextP0;
        p1 = nextP1;
        q0 = nextQ0;
        q1 = nextQ1;
        q2 = nextQ2;
        n >>= 1;
    }
    // With n reduced to 0 the answer is p0 / q0; skip the inversion when q0 is already 1.
    if(q0 == 1)
        return p0;
    return (long long)p0 * fpow(q0,mod - 2) % mod;
}
// Sieve state used by main: vis[i] marks composites, prime[1..cnt] lists the primes found,
// pw[i] stores i^k mod `mod` (pw[0] is never written and stays 0).
int vis[K + 5];
int cnt;
int prime[K + 5];
int pw[K + 5];
// fac[i] = i!, ifac[i] = inverse of i!, inv[i] = inverse of i (all mod `mod`); filled by main.
int fac[K + 5];
int ifac[K + 5];
int inv[K + 5];
// prod[i] = product of (n + 2 - j) for j = 0..i, iprod[i] = its inverse,
// inv_n_2[i] = inverse of (n + 2 - i); filled by main.
int prod[K + 5];
int iprod[K + 5];
int inv_n_2[K + 5];
// Running binomial-style products for the bases n + 1 and n + 2, advanced by main up to index k.
int C_n_1_k;
int C_n_2_k;
// Correction constants built by main's recurrence loop.
int C_1;
int C_2;
// Per-index coefficient tables for the two bases; main combines them into P.
int G_n_1[K + 5];
int G_n_2[K + 5];
// P holds the combined numerator terms; F holds the sequence that main weights by pw[] into ans.
int P[K + 5];
int F[K + 5];
// Accumulated result printed by main.
int ans;
// Binomial coefficient "n choose m" modulo mod, read from the fac / ifac tables.
// Returns 0 when m is negative or exceeds n. The tables must already cover index n.
inline int C(int n,int m)
{
    if(m < 0 || n < m)
        return 0;
    const long long partial = (long long)fac[n] * ifac[m] % mod;
    return partial * ifac[n - m] % mod;
}
int main()
{
    scanf("%lld%d",&n,&k);
    pw[1] = 1;
    for(int i = 2;i <= k;++i)
    {
        if(!vis[i])
            pw[prime[++cnt] = i] = fpow(i,k);
        for(int j = 1;j <= cnt && i * prime[j] <= k;++j)
        {
            vis[i * prime[j]] = 1,pw[i * prime[j]] = (long long)pw[i] * pw[prime[j]] % mod;
            if(!(i % prime[j]))
                break;
        }
    }
    if(n <= k)
    {
        F[0] = 1;
        for(int i = 0;i <= n;++i)
            i && (add(F[i],F[i - 1]),1),
            i > 1 && (add(F[i],F[i - 2]),1),
            fam(ans,F[i],pw[i]);
        printf("%d\n",ans);
        return 0;
    }
    F_n_1 = bostanMori(n - 1,1,0,1,mod - 1,mod - 1),
    F_n = bostanMori(n,1,0,1,mod - 1,mod - 1),
    add(F_n_1,F_n);
    n %= mod;
    prod[0] = norm(n + 2);
    for(int i = 1;i <= k + 1;++i)
        prod[i] = prod[i - 1] * (n + 2 - i + mod) % mod;
    iprod[k + 1] = fpow(prod[k + 1],mod - 2);
    for(int i = k + 1;i;--i)
        iprod[i - 1] = iprod[i] * (n + 2 - i + mod) % mod;
    for(int i = 1;i <= k + 1;++i)
        inv_n_2[i] = (long long)iprod[i] * prod[i - 1] % mod;
    inv_n_2[0] = iprod[0];
    fac[0] = 1;
    for(int i = 1;i <= k + 2;++i)
        fac[i] = (long long)fac[i - 1] * i % mod;
    ifac[k + 2] = fpow(fac[k + 2],mod - 2);
    for(int i = k + 2;i;--i)
        ifac[i - 1] = (long long)ifac[i] * i % mod;
    for(int i = 1;i <= k + 1;++i)
        inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
    C_n_1_k = C_n_2_k = 1;
    C_2 = norm(F_n_1 + F_n),sub(C_2,1);
    for(int i = 1,C_t;i <= k;++i)
        C_n_1_k = C_n_1_k * (n - i + 2 + mod) % mod * inv[i] % mod,
        C_n_2_k = C_n_2_k * (n - i + 3 + mod) % mod * inv[i] % mod,
        C_t = ((long long)F_n_1 * C_n_1_k + (long long)F_n * C_n_2_k - 3LL * C_2 - C_1 + 4LL * mod) % mod,
        C_1 = C_2,C_2 = C_t;
    fam(C_1,3,C_2);
    C_n_1_k = C_n_1_k * (n - k + 1 + mod) % mod,
    C_n_2_k = C_n_2_k * (n - k + 2 + mod) % mod;
    for(int i = 0;i <= k;++i)
        G_n_1[i] = (long long)(k - i & 1 ? mod - C_n_1_k : C_n_1_k) * C(k,i) % mod * inv_n_2[i + 1] % mod,
        G_n_2[i] = (long long)(k - i & 1 ? mod - C_n_2_k : C_n_2_k) * C(k,i) % mod * inv_n_2[i] % mod,
        P[i] = ((long long)F_n_1 * G_n_1[i] + (long long)F_n * G_n_2[i]) % mod,
        P[i] = neg(P[i]);
    add(P[0],1);
    for(int i = 0;i <= k;++i)
        F[i] = ((long long)(k - i & 1 ? mod - C_1 : C_1) * C(k + 1,i) + (long long)(k - i & 1 ? C_2 : mod - C_2) * C(k + 2,i)) % mod,
        add(F[i],P[i]),
        i && (add(F[i],F[i - 1]),1),
        i > 1 && (add(F[i],F[i - 2]),1),
        fam(ans,F[i],pw[i]);
    printf("%d\n",ans);
}