题解 CF1349D 【Slime and Biscuits】
ZigZagKmp
2021-03-13 15:21:23
无限sto dls
### 题意简述
第 $i$ 个人初始有 $a_i$ 个饼干,现在进行若干次操作,每一次操作会从所有饼干中随机一块,然后把这块饼干随机分给除了原主人以外的某个人。当所有饼干集中在一个人手里时,结束操作。问操作的期望次数,对 $998244353$ 取模。
$2\le n\le 100000\ ,\ 0\le a_i\le 300000$
### 算法分析
这里介绍一下dls发明的势能法,想法更自然,且可以作为解决这一类问题的通法。
我们考虑构造一个函数 $f$ ,满足进行一步操作之后 $\sum\limits_{i=1}^nf(a_i)$ 会 $-1$ ,则操作的期望次数就是 **初状态势函数和-末状态势函数和** ,本题即为:
$$\sum\limits_{i=1}^nf(a_i)-(n-1)f(0)-f(m)$$
考虑每一次操作,我们枚举选到的饼干是哪一个人的,再枚举这个饼干后来分给了谁,即:
$$\sum\limits_{i=1}^nf(a_i)=1+\sum\limits_{i=1}^n\frac{a_i}{m}\left(f(a_i-1)+\sum\limits_{j\neq i}\frac{1}{n-1}f(a_j+1)+\frac{n-2}{n-1}f(a_j)\right)$$
前面的 $1$ 是我们对势函数的要求,后面第 $i$ 个人的饼干被选中的概率是 $\frac{a_i}{m}$,在这次操作后,第 $i$ 个人的饼干个数会变成 $a_i-1$ ,再枚举这个饼干给了谁,剩下的 $n-1$ 个人每一个人都有 $\frac{1}{n-1}$ 的概率拿到这一块饼干,变成 $a_j+1$ ,其余情况不变。
我们稍微化一下这个式子:
$$\sum\limits_{i=1}^nf(a_i)=\sum\limits_{i=1}^n\frac{a_i}{m}(f(a_i-1)+1)+\frac{m-a_i}{m(n-1)}f(a_i+1)+\frac{(m-a_i)(n-2)}{m(n-1)}f(a_i)$$
这里我们把前面的 $+1$ 放到了 $\sum\limits_{i=1}^n\frac{a_i}{m}(f(a_i-1)+\color{red}1)$ ,不难证明这个是等价的。
这个方程看上去没啥用,但是因为我们只要**能构造出一个这样的势函数就可以**,因此不妨给它加强一下,**让每一个 $i$ 前后项都对应相等**,即:
$$f(a)=\frac{a}{m}(f(a-1)+1)+\frac{m-a}{m(n-1)}f(a+1)+\frac{(m-a)(n-2)}{m(n-1)}f(a)$$
这样我们就可以得到 $m+1$ 个这样的方程,下面可以用 ``band-matrix`` 或者用 $f(i)=k_if(m)+b_i$ 的方法解出这个方程组,时间复杂度是 $O(n)$ 或者 $O(n\log \mathrm{MOD})$ 的,取决于消元过程中用快速幂求逆元还是预处理出逆元。
### 代码实现
我写的是第二种方法,且没有预处理逆元,时间复杂度 $O(n\log n)$ 。
```cpp
#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define mod 998244353
#define eps 1e-6
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
int fh=1;char c=getchar();x=0;
while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int ksm(int b,int p,int Mod=mod){int ret=1;while(p){if(p&1)ret=1ll*ret*b%Mod;b=1ll*b*b%Mod;p>>=1;}return ret;}
int n,m;
int a[maxn],f[maxn];
struct node{//维护kf(m)+b的一个类型
int k,b;
node operator +(node y)const{
return (node){(k+y.k)%mod,(b+y.b)%mod};
}
node operator -(node y)const{
return (node){(k+mod-y.k)%mod,(b+mod-y.b)%mod};
}
node operator *(int y)const{
return (node){1ll*k*y%mod,1ll*b*y%mod};
}
node operator +(int y)const{
return (node){k,(b+y)%mod};
}
node operator -(int y)const{
return (node){k,(b+mod-y)%mod};
}
};
node ff[maxn];
signed main(){
read(n);
for(int i=1;i<=n;++i)read(a[i]),m+=a[i];
ff[m]=(node){1,0};
for(int i=m-1;i>=1;--i){//代入消元,求出每一个f(i)的 kf(m)+b 表示
int di=1ll*(m-i)*ksm(1ll*m*(n-1)%mod,mod-2)%mod;
int di1=1ll*i*ksm(m,mod-2)%mod;
int di2=1ll*(m-i)*(n-2)%mod*ksm(1ll*m*(n-1)%mod,mod-2)%mod;
ff[i-1]=ff[i]-ff[i]*di2-ff[i+1]*di;
ff[i-1]=ff[i-1]-di1;
ff[i-1]=ff[i-1]*ksm(di1,mod-2);
}//边界(手算可以发现)f(0)=f(1),直接解方程
int f0=1ll*(ff[0].b-ff[1].b)*ksm(ff[1].k-ff[0].k,mod-2)%mod;
for(int i=0;i<=m;++i){//还原f(i)
f[i]=(1ll*f0*ff[i].k%mod+ff[i].b)%mod;
f[i]=(f[i]+mod)%mod;
}
int ans=0;//求答案
for(int i=1;i<=n;++i)ans=(ans+f[a[i]])%mod;
ans=(ans-1ll*(n-1)*f[0]%mod)%mod;
ans=(ans-f[m])%mod;
ans=(ans+mod)%mod;
printf("%d\n",ans);
return 0;
}
```