题解 P6633 【[ZJOI2020] 抽卡】
定义存在连续
朴素的想法是考虑状压,因为状态转移没有比自环大的环,且合法状态不会转化成非法状态,所以第一次出现合法状态的期望时间可以表示成所有非法状态的出现概率乘以到下次转移的期望时间。
如果状态中抽了
对每一个连续段分别计数,假设当前连续段的长度为
就是每次加入一段连续的
设
二元生成函数也是可以拉格朗日反演的,所以直接上扩展拉格朗日反演:
这样挂在
因为有
牛顿迭代即可在
最后需要将每一段对应的生成函数乘起来,所以总复杂度为
#include<bits/stdc++.h>
using namespace std;
#define I inline int
#define V inline void
#define ll long long int
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define ROF(i,a,b) for(int i=a;i>=b;i--)
const int N=1<<20,mod=998244353;
int m,k,tot,lmt,ans,a[N];
int big[N<<5],*np(big),*t[N],f[N];
int w[N],r[N],fac[N],inv[N],ifac[N],len[N];
V check(int&x){x-=mod,x+=x>>31&mod;}
I Pow(ll t,int x,ll s=1){
for(;x;x>>=1,t=t*t%mod)if(x&1)s=s*t%mod;
return s;
}
V cl(int*op,int*ed){memset(op,0,ed-op<<2);}
I getLen(int n){return 1<<32-__builtin_clz(n);}
V mul(int*a,ll k,int l,int*b){while(l--)*b++=k**a++%mod;}
I C(int x,int y){return 1ll*fac[x]*ifac[y]%mod*ifac[x-y]%mod;}
V dot(int*a,int*b,int l,int*c){while(l--)*c++=1ll**a++**b++%mod;}
V init(int n){
int l=-1,wn;
for(lmt=inv[0]=inv[1]=fac[0]=ifac[0]=1;lmt<=n;lmt<<=1)l++;
FOR(i,2,n-1)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
FOR(i,1,n-1)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
FOR(i,1,lmt-1)r[i]=r[i>>1]>>1|(i&1)<<l;
wn=Pow(3,mod>>++l),w[lmt>>1]=1;
FOR(i,lmt/2+1,lmt-1)w[i]=1ll*w[i-1]*wn%mod;
ROF(i,lmt/2-1,1)w[i]=w[i<<1];
}
V DFT(int*a,int l){
static unsigned long long int tmp[N];
int t,u=__builtin_ctz(lmt/l);
FOR(i,0,l-1)tmp[i]=a[r[i]>>u];
for(int i=1;i<l;i<<=1)for(int j=0,d=i<<1;j<l;j+=d)FOR(k,0,i-1)
t=tmp[i|j|k]*w[i|k]%mod,tmp[i|j|k]=tmp[j|k]+mod-t,tmp[j|k]+=t;
FOR(i,0,l-1)a[i]=tmp[i]%mod;
}
V IDFT(int*a,int l){reverse(a+1,a+l),DFT(a,l),mul(a,mod-mod/l,l,a);}
V Inv(const int*a,int n,int*b){
if(n==1)return void(b[0]=Pow(a[0],mod-2));
static int tmp[N],l;
Inv(a,n+1>>1,b),copy(a,a+n,tmp),DFT(tmp,l=getLen(n<<1)),DFT(b,l);
FOR(i,0,l-1)b[i]=(2ll+mod-1ll*tmp[i]*b[i]%mod)*b[i]%mod;
IDFT(b,l),cl(b+n,b+l),cl(tmp,tmp+l);
}
V conv(int*a,int*b,int l,int*c){DFT(a,l),DFT(b,l),dot(a,b,l,c),IDFT(c,l);}
V deri(const int*a,int n,int*b){FOR(i,1,n-1)b[i-1]=1ll*i*a[i]%mod;b[n-1]=0;}
V inte(const int*a,int n,int*b){FOR(i,1,n-1)b[i]=1ll*inv[i]*a[i-1]%mod;b[0]=0;}
V Ln(const int*a,int n,int*b){
static int tmp[N],l;
l=getLen(n<<1),Inv(a,n,tmp),deri(a,n,b);
conv(tmp,b,l,tmp),inte(tmp,n,b),cl(b+n,b+l),cl(tmp,tmp+l);
}
V Exp(const int*a,int n,int*b){
if(n==1)return void(b[0]=1);
static int tmp[N],l;
Exp(a,n+1>>1,b),Ln(b,n,tmp),tmp[0]--;
FOR(i,0,n-1)check(tmp[i]=a[i]+mod-tmp[i]);
conv(b,tmp,l=getLen(n<<1),b),cl(b+n,b+l),cl(tmp,tmp+l);
}
V Mul(const int*a,const int*b,int n,int m,int*c){
static int X[N],Y[N],l;
l=getLen(n+m-1),copy(a,a+n,X),copy(b,b+m,Y);
conv(X,Y,l,X),copy(X,X+n+m-1,c),cl(X,X+l),cl(Y,Y+l);
}
V Pow(const int*a,int n,int k,int*b){
static int tmp[N];
if(n>1)Ln(a,n,tmp),mul(tmp,k,n,tmp),Exp(tmp,n,b),cl(tmp,tmp+n);
}
V solve(int n){
if(n==1)return void(f[0]=0);
static int tmp[N],up[N],dw[N],l;
solve(n+1>>1),l=getLen(n<<1),Pow(f+1,n-k,k,tmp+k),mul(tmp,mod-1,n,tmp);
FOR(i,0,n-1)dw[i]=1ll*(k+1)*tmp[i]%mod;
dw[0]++,dw[1]++,tmp[0]++,tmp[1]++,Mul(f,tmp,n,n,up),check(up[1]+=mod-1);
cl(tmp,tmp+l),Inv(dw,n,tmp),cl(up+n,up+n+n-1),conv(tmp,up,l,tmp);
FOR(i,0,n-1)check(f[i]+=mod-tmp[i]);
cl(tmp,tmp+l),cl(up,up+l),cl(dw,dw+l);
}
V calc(int*g,int n){
Pow(f+1,n+1,mod-n-1,g);
FOR(i,0,n)g[i]=1ll*(n-i+1)*g[i]%mod*inv[n+1]%mod;
}
int main(){
scanf("%d%d",&m,&k);
FOR(i,1,m)scanf("%d",a+i);
sort(a+1,a+m+1),init(m+2<<1);
for(int p=1,d=0;p<=m;p+=d,len[++tot]=d+1)for(d=0;a[p+d]==a[p]+d;d++);
sort(len+1,len+1+tot),solve(len[tot]+1);
FOR(i,1,tot)
if(t[i]=np,np+=len[i],len[i]==len[i-1])copy(t[i-1],t[i-1]+len[i],t[i]);
else calc(t[i],len[i]-1);
for(int x=1,y;y=x+1,x<tot;len[tot]=len[x]+len[y]-1,np+=len[tot],x+=2)
t[++tot]=np,Mul(t[x],t[y],len[x],len[y],t[tot]);
FOR(i,0,m)check(ans+=1ll*t[tot][i]*Pow(1ll*C(m,i)*(m-i)%mod,mod-2,m)%mod);
return cout<<ans,0;
}