CZOI-R3 星光闪耀
hyk2019
·
·
题解
\gdef\mat#1#2{\begin{pmatrix}#1\\#2\end{pmatrix}}
部分分:
正解:巧妙的推式子。
这里先假设 k\ge2。
一个大小为 v 的星团,会生出大小为 1\sim v-1 之间所有整数的星团。按照题目的要求,我们最终要求所有星团的闪耀度之和,而一个星团的闪耀度是 k^v。想一想,一次操作后会对星团的闪耀度之和有什么影响?
k^v\rightarrow\sum_{i=1}^{v-1}k^i
右边是等比数列,如何求和?
\sum_{i=1}^{v-1}k^i=(\sum_{i=0}^{v-1}k^i)-1=\frac{k^v-1}{k-1}-1
如果 u_{a,i} 表示经过 a 次操作后大小为 i 的星团的数量,那么闪耀度之和 S_a 可以看成什么?
S_a=\sum_{i=1}^nu_{a,i}\times k^i
如何转移到 S_{a+1}?
\begin{aligned}
S_{a+1}
&=\sum_{i=1}^nu_{a,i}\times (\frac{k^i-1}{k-1}-1)\\
&=\frac1{k-1}\times((\sum_{i=1}^nu_{a,i}\times k^i)-(\sum_{i=1}^nu_{a,i}))-(\sum_{i=1}^nu_{a,i})\\
&=\frac{S_a-T_a}{k-1}-T_a
\end{aligned}
上面的 T_a=\sum_{i=1}^nu_{a,i} 表示的就是星团个数。
$u_{a,i}$ 的递推规则如下:
$$
u_{a,i}=\sum_{j=i}^nu_{a-1,i}
$$
如何将其简化?
当 $a>0$ 时:
$$
u_{a,i}=\begin{cases}1&i=n\\u_{a-1,i}+u_{a,i+1}&1\le i<n\end{cases}
$$
当 $a=0$ 时:
$$
u_{0,i}=\begin{cases}1&i=n\\0&1\le i<n\end{cases}
$$
实际上,当 $a=1$ 时,$u_{1,i}=1$。所以接下来的这一部分看着很像杨辉三角?
这样,当 $a\ge1$ 时,就可以进一步将递推公式简化成通项公式:
$$
\begin{aligned}
u_{a,i}
&=\mat{a+n-i-1}{a-1}\\
&=\mat{a+n-i-2}{a-1}+\mat{a+n-i-2}{a-2}
\end{aligned}
$$
$$
\begin{aligned}
u_{a+1,i}
&=\mat{a+n-i-1}a+\mat{a+n-i-1}{a-1}\\
&=\mat{a+n-i-1}a+u_{a,i}
\end{aligned}
$$
所以:
$$
T_{a+1}=T_a+\sum_{i=1}^n\mat{a+n-i-1}a
$$
接下来处理后半部分。
$$
\sum_{i=1}^n\mat{a+n-i-1}a=\sum_{j=0}^{n-2}\mat{a+j}a
$$
由于 $\sum_{k=0}^m\mat{r+k}r=\mat{r+m+1}{r+1}$,因此:
$$
\sum_{j=0}^{n-2}\mat{a+j}a=\mat{a+n-1}{a+1}
$$
而:
$$
\begin{aligned}
T_a
&=\sum_{i=1}^n\mat{a+n-i-1}{a-1}\\
&=\mat{a+n-1}{a}
\end{aligned}
$$
所以:
$$
\begin{aligned}
T_{a+1}
&=T_a+\mat{a+n-1}{a+1}\\
&=\frac{a+n}{a+1}\times T_a
\end{aligned}
$$
如果赛时不会推式子,也可以打表,观察相邻两个 $T_a$ 的商,可以得出同样的结论。
---
$T_a$ 解决了,$S_a$ 也解决了,那么整条思路就打通了。
我们已知 $S_0=k^n$,$T_0=1$,接下来只要根据公式 $S_{a+1}=\frac{S_a-T_a}{k-1}-T_a$ 和 $T_{a+1}=\frac{a+n}{a+1}\times T_a$ 转移到 $a=m$ 即可。预处理每个数的逆元,最终时间复杂度 $O(\sum m)$。
---
```cpp
#include <bits/stdc++.h>
using namespace std;
const long long MOD = 998244353;
int T, N, M;
long long K, I[2000006], F[2000006];
long long power(long long x, long long y) {
if(y == 0)
return 1;
if(y == 1)
return x;
long long w = power(x * x % MOD, y >> 1);
if(y & 1)
return w * x % MOD;
else
return w;
}
int main() {
F[0] = 1;
int U = 2000005;
for(int i = 1; i <= U; i ++)
F[i] = F[i - 1] * i % MOD;
I[U] = power(F[U], MOD - 2);
for(int i = U - 1; i; i --)
I[i] = I[i + 1] * (i + 1) % MOD;
for(int i = 1; i <= U; i ++)
I[i] = I[i] * F[i - 1] % MOD;
scanf("%d", &T);
while(T --) {
scanf("%d%d%lld", &N, &M, &K);
if(K == 0) {
puts("0");
continue;
}
long long w = power(K, N), inv = power(K - 1, MOD - 2), C = 1;
for(int i = 0; i < M; i ++) {
if(K > 1)
w = ((w + (w - C) * inv - C) % MOD + MOD) % MOD;
C = C * (N + i) % MOD * I[i + 1] % MOD;
}
printf("%lld\n", K == 1 ? C : w);
}
return 0;
}
```