CZOI-R3 星光闪耀

· · 题解

\gdef\mat#1#2{\begin{pmatrix}#1\\#2\end{pmatrix}}

部分分:

正解:巧妙的推式子。

这里先假设 k\ge2

一个大小为 v 的星团,会生出大小为 1\sim v-1 之间所有整数的星团。按照题目的要求,我们最终要求所有星团的闪耀度之和,而一个星团的闪耀度是 k^v。想一想,一次操作后会对星团的闪耀度之和有什么影响?

k^v\rightarrow\sum_{i=1}^{v-1}k^i

右边是等比数列,如何求和?

\sum_{i=1}^{v-1}k^i=(\sum_{i=0}^{v-1}k^i)-1=\frac{k^v-1}{k-1}-1

如果 u_{a,i} 表示经过 a 次操作后大小为 i 的星团的数量,那么闪耀度之和 S_a 可以看成什么?

S_a=\sum_{i=1}^nu_{a,i}\times k^i

如何转移到 S_{a+1}

\begin{aligned} S_{a+1} &=\sum_{i=1}^nu_{a,i}\times (\frac{k^i-1}{k-1}-1)\\ &=\frac1{k-1}\times((\sum_{i=1}^nu_{a,i}\times k^i)-(\sum_{i=1}^nu_{a,i}))-(\sum_{i=1}^nu_{a,i})\\ &=\frac{S_a-T_a}{k-1}-T_a \end{aligned}

上面的 T_a=\sum_{i=1}^nu_{a,i} 表示的就是星团个数。

$u_{a,i}$ 的递推规则如下: $$ u_{a,i}=\sum_{j=i}^nu_{a-1,i} $$ 如何将其简化? 当 $a>0$ 时: $$ u_{a,i}=\begin{cases}1&i=n\\u_{a-1,i}+u_{a,i+1}&1\le i<n\end{cases} $$ 当 $a=0$ 时: $$ u_{0,i}=\begin{cases}1&i=n\\0&1\le i<n\end{cases} $$ 实际上,当 $a=1$ 时,$u_{1,i}=1$。所以接下来的这一部分看着很像杨辉三角? 这样,当 $a\ge1$ 时,就可以进一步将递推公式简化成通项公式: $$ \begin{aligned} u_{a,i} &=\mat{a+n-i-1}{a-1}\\ &=\mat{a+n-i-2}{a-1}+\mat{a+n-i-2}{a-2} \end{aligned} $$ $$ \begin{aligned} u_{a+1,i} &=\mat{a+n-i-1}a+\mat{a+n-i-1}{a-1}\\ &=\mat{a+n-i-1}a+u_{a,i} \end{aligned} $$ 所以: $$ T_{a+1}=T_a+\sum_{i=1}^n\mat{a+n-i-1}a $$ 接下来处理后半部分。 $$ \sum_{i=1}^n\mat{a+n-i-1}a=\sum_{j=0}^{n-2}\mat{a+j}a $$ 由于 $\sum_{k=0}^m\mat{r+k}r=\mat{r+m+1}{r+1}$,因此: $$ \sum_{j=0}^{n-2}\mat{a+j}a=\mat{a+n-1}{a+1} $$ 而: $$ \begin{aligned} T_a &=\sum_{i=1}^n\mat{a+n-i-1}{a-1}\\ &=\mat{a+n-1}{a} \end{aligned} $$ 所以: $$ \begin{aligned} T_{a+1} &=T_a+\mat{a+n-1}{a+1}\\ &=\frac{a+n}{a+1}\times T_a \end{aligned} $$ 如果赛时不会推式子,也可以打表,观察相邻两个 $T_a$ 的商,可以得出同样的结论。 --- $T_a$ 解决了,$S_a$ 也解决了,那么整条思路就打通了。 我们已知 $S_0=k^n$,$T_0=1$,接下来只要根据公式 $S_{a+1}=\frac{S_a-T_a}{k-1}-T_a$ 和 $T_{a+1}=\frac{a+n}{a+1}\times T_a$ 转移到 $a=m$ 即可。预处理每个数的逆元,最终时间复杂度 $O(\sum m)$。 --- ```cpp #include <bits/stdc++.h> using namespace std; const long long MOD = 998244353; int T, N, M; long long K, I[2000006], F[2000006]; long long power(long long x, long long y) { if(y == 0) return 1; if(y == 1) return x; long long w = power(x * x % MOD, y >> 1); if(y & 1) return w * x % MOD; else return w; } int main() { F[0] = 1; int U = 2000005; for(int i = 1; i <= U; i ++) F[i] = F[i - 1] * i % MOD; I[U] = power(F[U], MOD - 2); for(int i = U - 1; i; i --) I[i] = I[i + 1] * (i + 1) % MOD; for(int i = 1; i <= U; i ++) I[i] = I[i] * F[i - 1] % MOD; scanf("%d", &T); while(T --) { scanf("%d%d%lld", &N, &M, &K); if(K == 0) { puts("0"); continue; } long long w = power(K, N), inv = power(K - 1, MOD - 2), C = 1; for(int i = 0; i < M; i ++) { if(K > 1) w = ((w + (w - C) * inv - C) % MOD + MOD) % MOD; C = C * (N + i) % MOD * I[i + 1] % MOD; } printf("%lld\n", K == 1 ? C : w); } return 0; } ```