题解 CF1278F 【Cards】

NaCly_Fish

2020-01-29 03:13:18

Solution

在我的博客内看体验更佳( updated:改为 $\Theta(k)$ 的解法。 ps:这个复杂度的做法 [iostream](https://www.luogu.com.cn/user/13052) 比我先做出来。 设 $p=1/m$,$q=1-p$,那么枚举第一张为王牌的次数,有这么一个暴力计算的式子 $$\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}i^k$$ 后面那个 $i^k$ 可以展开为第二类 Stirling 数 $$\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}\sum_{j=1}^k \begin{Bmatrix} k \\ j \end{Bmatrix}i^{\underline j}$$ $$=\sum_{j=1}^k\begin{Bmatrix} k \\ j \end{Bmatrix}\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}i^{\underline j}$$ 考虑化简后面那个式子 $$=\sum_{i=j}^n \binom ni p^i(1-p)^{n-i}\binom{i}{j}j!$$ $$=j!\binom{n}{j}\sum_{i=j}^n\binom{n-j}{i-j}p^i(1-p)^{n-i}$$ $$=n^{\underline j} \sum_{i=0}^{n-j}\binom{n-j}{i}p^{i+j}(1-p)^{n-j-i}$$ $$=n^{\underline j}p^j$$ 然后直接得到原式为 $$\sum_{j=0}^k\begin{Bmatrix} k \\ j \end{Bmatrix}n^{\underline j}p^j$$ 这样就可以 $\Theta(k \log k)$ 计算了,然而还可以更优。 暴力拆开原式的 Stirling 数得到 $$ \sum_{j=0}^k\frac{1}{j!}\sum_{i=0}^j(-1)^{j-i}\binom{j}{i} i^k j! \binom nj p^j$$ $$=\sum_{i=0}^ki^k\sum_{j=i}^k(-1)^{j-i} \binom nj \binom ji p^j$$ 根据组合数的基本意义可以化为 $$\sum_{i=0}^ki^k\binom ni\sum_{j=i}^k(-1)^{j-i}\binom{n-i}{j-i}p^j$$ 后面一个和式改为枚举从 $0$ 到 $k-i$,所有含 $j$ 的式子都 $+i$ 得到 $$\sum_{i=0}^ki^k\binom ni p^i\sum_{j=0}^{k-i}(-1)^j\binom{n-i}{j}p^j$$ 设后面那个和式为 $f(i)$,显然有 $f(k)=1$。 做个差分,再做一些麻烦的推式子,就能得到如下关系: (我的推法又臭又长,而且也没什么技术含量,就不放上来了) $$f(i)=(-p)^{k-i}\binom{n-i-1}{k-i}+(1-p)f(i+1)$$ 用线性筛求出 $i^k \ (i\in [1,k])$,时间复杂度就可以做到 $\Theta(k)$。 要注意的是 $n \le k$ 的时候会有点问题,直接用最开始的式子,暴力计算即可。 ```cpp #include<cstdio> #include<iostream> #include<algorithm> #include<cstring> #include<cmath> #define N 10000003 #define ll long long #define p 998244353 #define reg register using namespace std; inline int power(int a,int t){ int res = 1; while(t){ if(t&1) res = (ll)res*a%p; a = (ll)a*a%p; t >>= 1; } return res; } int n,m,k,cnt; int f[N],inv[N],pw[N],pr[N>>1],c[N]; int solve1(){ int mul = (ll)power(p+1-m,n-1)*m%p,invq = power(p+1-m,p-2),res = 0; for(reg int i=1;i<=n;++i){ res = (res+(ll)mul*c[i]%p*pw[i])%p; mul = (ll)mul*invq%p*m%p; } return res; } int solve2(){ int c2,mul,res = 0; mul = c2 = f[k] = 1; for(reg int i=k-1;i;--i){ c2 = (ll)c2*(n-i-1)%p*inv[k-i]%p; mul = (ll)mul*(p-m)%p; f[i] = ((ll)c2*mul+(ll)(p+1-m)*f[i+1])%p; } mul = m; for(reg int i=1;i<=k;++i){ res = (res+(ll)pw[i]*c[i]%p*mul%p*f[i])%p; mul = (ll)mul*m%p; } return res; } int main(){ scanf("%d%d%d",&n,&m,&k); m = power(m,p-2); c[0] = inv[1] = pw[1] = 1; c[1] = n; for(reg int i=2;i<=k;++i){ inv[i] = (ll)(p-p/i)*inv[p%i]%p; c[i] = (ll)c[i-1]*inv[i]%p*(n-i+1)%p; if(!pw[i]){ pr[++cnt] = i; pw[i] = power(i,k); } for(reg int j=1;j<=cnt&&(ll)i*pr[j]<=k;++j){ pw[i*pr[j]] = (ll)pw[i]*pw[pr[j]]%p; if(i%pr[j]==0) break; } } if(n<=k+1) printf("%d",solve1()); else printf("%d",solve2()); return 0; } ```