题解 P5364 【[SNOI2017]礼物】

_rqy

2019-05-25 10:07:12

Solution

@Fading 首先我要告诉您您的做法不是最优... $O(k+\log n)$!(如果谁有更优的做法的话请告知我...) 首先,容易发现这个东西可以写成矩阵乘法的形式...(如果你没有发现,请看最早那篇题解)...然后可以惊奇的发现,矩阵是上三角矩阵,并且对角线上有一个 $2$ 和 $k+1$ 个 $1$! 因为矩阵是上三角,所以显然其特征值就是主对角线上的元素,所以根据递推通项公式之类的性质,可以发现答案一定是这种形式: $$c_k2^n+\sum_{i=0}^ka_{k,i}n^i$$ 即一个 $2^n$ 的项加上一个 $k$ 次多项式。如果我们知道了 $c_k$ 的值,那么通过预处理 $[1, k+1]$ 项就可以拉格朗日插值得到 $n$ 的值。(拉格朗日插值可以做到 $O(k)$) 众所周知 $c_k2^n$ 的差分仍然是 $c_k2^n$,但是 $k$ 次多项式的差分是 $k-1$ 阶多项式,所以把前 $k+1$ 项进行 $k$ 阶差分(这个可以直接用组合数 $O(k)$ 得到)就可以得到 $c_k$ 的值。 这样的话,复杂度瓶颈还剩下预处理前 $k+1$ 项。而可以发现只需要求出 $1^k\dots (k+1)^k$,而这可以通过欧拉筛只求其中素数的幂,从而做到 $O(\pi(k)\log k+k)=O(k)$。 于是总复杂度就是 $O(k+\log n)$。($O(\log n)$是计算 $2^n$ ~~和输入n~~)($n$ 更大也不用高精度,只需要对 $10^9+7$ 取模算插值,$10^9+6$ 取模算 $2$ 的幂即可) 代码 ```cpp // luogu-judger-enable-o2 #include <algorithm> #include <cctype> #include <cstdio> #include <cstring> typedef long long LL; const int K = 15; const int mod = 1000000007; LL pow_mod(LL a, LL b) { LL ans = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) ans = ans * a % mod; return ans; } LL pK[K], f[K], inv[K]; int pr[K], cnt; void Sieve(int k) { pK[1] = 1; for (int i = 2; i <= k + 1; ++i) { if (!pK[i]) pK[pr[cnt++] = i] = pow_mod(i, k); for (int j = 0; pr[j] * i <= k + 1; ++j) { pK[i * pr[j]] = pK[i] * pK[pr[j]] % mod; if (i % pr[j] == 0) break; } } } int main() { int k; LL n; scanf("%lld%d", &n, &k); --n; LL s = 0, tn = n % mod; Sieve(k); for (int i = 0; i <= k; ++i) { if ((f[i] = s + pK[i + 1]) >= mod) f[i] -= mod; if ((s += f[i]) >= mod) s -= mod; } if (n <= k) return printf("%lld\n", f[n]) & 0; LL g = -1; inv[1] = 1; for (int i = 2; i <= k; ++i) { inv[i] = -(mod / i) * inv[mod % i] % mod; g = g * -inv[i] % mod; } LL c = 1, p = 0; for (int i = 0; i <= k; ++i) { // sum_{i=0}^k (-1)^(k-i) C(k, i) f[i+1] LL _t = c * f[i] % mod; c = c * inv[i + 1] % mod * (k - i) % mod; if ((k - i) & 1) p -= _t; else p += _t; } LL ans = 0, p2 = p, t = 1; for (int i = 0; i <= k; ++i) { LL _y = (f[i] - p2) * g % mod; ans = (ans * (tn - i) + _y * t) % mod; t = t * (tn - i) % mod; p2 = p2 * 2 % mod; g = g * (i - k) % mod * inv[i + 1] % mod; } ans = (ans + p * pow_mod(2, n % (mod - 1))) % mod; printf("%lld\n", (ans + mod) % mod); } ```