P6667 [清华集训2016] 如何优雅地求和 题解

· · 题解

P6667 [清华集训2016] 如何优雅地求和

感觉组合数问题那道题就是这道题把多项式去掉之后的玩意。

题意

给出一个 m 次多项式函数 f(x)n,x 求下列式子的值:

\sum_{k=0}^nf(k)\dbinom{n}{k}x^k(1-x)^{n-k}\bmod{998,244,353}

其中 f(x) 以点值形式给出,给出 a_0,a_1,\cdots,a_m,满足 f(i)=a_i。(1\le n\le 10^9,1\le m\le 2\times10^4)

题解

注意到组合数是能拆成下降幂和阶乘的形式的,而这个形式对于能改变为阶乘形式的下降幂尤其友好,而对于普通幂却十分不便。所以我们考虑把 f 看做一个下降幂多项式函数,这样原式子就可以化为:

\begin{aligned}&\sum_{k=0}^n\sum_{i=0}^mb_ik^{\underline{i}}\dbinom{n}{k}x^k(1-x)^{n-k}\\=&\sum_{i=0}^mb_i\sum_{k=0}^n\dfrac{k!}{(k-i)!}\dfrac{n!}{(n-k)!k!}x^k(1-x)^{n-k}\\=&\sum_{i=0}^mb_i\sum_{k=0}^n\dfrac{(n-i)^{\underline{k-i}}}{(k-i)!(n-k)!}n^{\underline{i}}x^k(1-x)^{n-k}\\=&\sum_{i=0}^mb_in^{\underline{i}}\sum_{k=0}^n\dbinom{n-i}{k-i}x^k(1-x)^{n-k}\end{aligned}

到这一步之后,我们发现问题集中在了 k 相关的那个和式,考虑枚举 k-i,并舍弃掉没用的负数:

\sum_{i=0}^mb_in^{\underline{i}}x^i\sum_{k=0}^{n-i}\dbinom{n-i}{k}x^{k}(1-x)^{n-i-k}

而后面很明显地出现了二项式定理的形式,所以原式为:

\sum_{i=0}^mb_in^{\underline{i}}x^i(x+1-x)^{n-i}=\sum_{i=0}^mb_in^{\underline{i}}x^i

现在这个式子就非常舒服了,只要我们求出 b_i,即 f 这个下降幂多项式函数的系数,就能做到 \mathcal{O}(m) 求解答案了。

想知道怎么把下降幂多项式函数的系数找出来,我们可以考虑把这个下降幂转化为普通幂,这样我们也就知道了系数的 \rm OGF。刚刚说了,下降幂和阶乘很搭,所以考虑下降幂的 \rm EGF

\begin{aligned}\hat{U}_k(z)&=\sum_{n\ge k}\dfrac{n^{\underline{k}}z^n}{n!}\\&=\sum_{n\ge k}\dfrac{\frac{n!}{(n-k)!}z^n}{n!}\\&=\sum_{n\ge k}\dfrac{z^n}{(n-k)!}\\&=z^k\sum_{n\ge 0}\dfrac{z^n}{n!}=z^ke^z\end{aligned}

然后我们回到题目里给出的点值,考虑它的 \rm EGF

\begin{aligned}\sum_{n\ge 0}\dfrac{f(n)z^n}{n!}&=\sum_{n\ge 0}\dfrac{z^n}{n!}\sum_{j=0}^mf_jn^{\underline{j}}\\&=\sum_{j=0}^mf_j\sum_{n\ge 0}\dfrac{n^{\underline{j}}z^n}{n!}\\&=\sum_{j=0}^mf_j\hat{U}_j(z)\\&=e^z\sum_{j=0}f_jz^j\end{aligned}

也就是说,如果我们设点值的 \rm EGF\hat{A}(z),系数的 \rm OGFB(z),则它们满足:

\hat{A}(z)=e^zB(z)

即:

B(z)=e^{-z}\hat{A}(z) ```cpp #include <cstdio> #include <algorithm> const int N = 1e6 + 10, mod = 998244353; typedef long long ll; inline int ksm(int a, int b) { int ret = 1; while (b) { if (b & 1) ret = (ll)ret * a % mod; a = (ll)a * a % mod; b >>= 1; } return ret; } int rev[N], lim, m, n; inline void init(int n) { lim = 1; m = 0; while (lim <= n) lim <<= 1, ++m; for (int i = 0; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (m - 1)); } inline void NTT(int* f, int len, int on) { for (int i = 0; i < len; ++i) if (i < rev[i]) std::swap(f[i], f[rev[i]]); for (int h = 2; h <= len; h <<= 1) { int gn = ksm(3, (ll)(mod - 1) / h * on % (mod - 1)); for (int j = 0; j < len; j += h) for (int k = j, g = 1; k < j + h / 2; ++k, g = (ll)g * gn % mod) { int u = f[k], t = (ll)g * f[k + h / 2] % mod; f[k] = (u + t) % mod; f[k + h / 2] = ((u - t) % mod + mod) % mod; } } if (on == mod - 2) for (int i = 0, inv = ksm(len, on); i < len; ++i) f[i] = (ll)f[i] * inv % mod; } int fac[N], ifac[N], F[N], G[N]; int main() { int n, m, x; scanf("%d%d%d", &n, &m, &x); fac[0] = ifac[0] = 1; for (int i = 1; i <= m; ++i) fac[i] = (ll)fac[i - 1] * i % mod; ifac[m] = ksm(fac[m], mod - 2); for (int i = m - 1; i >= 1; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % mod; for (int i = 0; i <= m; ++i) { scanf("%d", &F[i]); F[i] = (ll)F[i] * ifac[i] % mod; G[i] = ((i & 1) ? mod - ifac[i] : ifac[i]); } init(m << 1); NTT(F, lim, 1); NTT(G, lim, 1); for (int i = 0; i < lim; ++i) F[i] = (ll)F[i] * G[i] % mod; NTT(F, lim, mod - 2); int ni = 1, xi = 1, ans = 0; for (int i = 0; i <= m; xi = (ll)xi * x % mod, ni = (ll)ni * (n - i) % mod, ++i) (ans += (ll)F[i] * ni % mod * xi % mod) %= mod; printf("%d\n", ans); return 0; } ```