题解 P6633 【[ZJOI2020] 抽卡】

· · 题解

定义存在连续 k 张卡的状态为合法状态,不存在的为非法状态。

朴素的想法是考虑状压,因为状态转移没有比自环大的环,且合法状态不会转化成非法状态,所以第一次出现合法状态的期望时间可以表示成所有非法状态的出现概率乘以到下次转移的期望时间。

如果状态中抽了 r 张卡,那么这个状态有 \frac{1}{\binom{m}{r}} 的概率出现,并且期望经过 \frac{m}{m-r} 的时间后会转移到下个状态。所以只需要计算出取 r 张牌的非法状态数量即可。

对每一个连续段分别计数,假设当前连续段的长度为 n,那么选 r 张牌的非法状态数为:

[x^{n+1-r}](\sum_{i=0}^{k-1}x^i)^r

就是每次加入一段连续的 1 和紧跟着的一个 01 的数量可以直接用加入的段数,即 r 进行描述。

G(x)=\sum_{i=1}^{k}x^i,那么就是求:

[x^{n+1}]\frac{1}{1-uG(x)}

二元生成函数也是可以拉格朗日反演的,所以直接上扩展拉格朗日反演:

[x^{n+1}]\frac{1}{1-uG(x)}=\frac{1}{n+1}[x^n]\frac{u}{(1-ux)^2}(\frac{x}{G^{-1}(x)})^{n+1}

这样挂在 u 上面的 x 只有一项,就只剩下最右边的一项需要计算了。

因为有 G(x)=\frac{x^{k+1}-x}{x-1},所以求 G^{-1}(x) 等价于求一 F(x) 使得:

\frac{F(x)^{k+1}-F(x)}{F(x)-1}=x\Leftrightarrow (1+x)F(x)-F(x)^{1+k}-x=0

牛顿迭代即可在 O(nlogn) 的时间内求出。

最后需要将每一段对应的生成函数乘起来,所以总复杂度为 O(mlog^2m)

#include<bits/stdc++.h>
using namespace std;
#define I inline int
#define V inline void
#define ll long long int
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define ROF(i,a,b) for(int i=a;i>=b;i--)
const int N=1<<20,mod=998244353;
int m,k,tot,lmt,ans,a[N];
int big[N<<5],*np(big),*t[N],f[N];
int w[N],r[N],fac[N],inv[N],ifac[N],len[N];
V check(int&x){x-=mod,x+=x>>31&mod;}
I Pow(ll t,int x,ll s=1){
    for(;x;x>>=1,t=t*t%mod)if(x&1)s=s*t%mod;
    return s;
}
V cl(int*op,int*ed){memset(op,0,ed-op<<2);}
I getLen(int n){return 1<<32-__builtin_clz(n);}
V mul(int*a,ll k,int l,int*b){while(l--)*b++=k**a++%mod;}
I C(int x,int y){return 1ll*fac[x]*ifac[y]%mod*ifac[x-y]%mod;}
V dot(int*a,int*b,int l,int*c){while(l--)*c++=1ll**a++**b++%mod;}
V init(int n){
    int l=-1,wn;
    for(lmt=inv[0]=inv[1]=fac[0]=ifac[0]=1;lmt<=n;lmt<<=1)l++;
    FOR(i,2,n-1)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    FOR(i,1,n-1)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
    FOR(i,1,lmt-1)r[i]=r[i>>1]>>1|(i&1)<<l;
    wn=Pow(3,mod>>++l),w[lmt>>1]=1;
    FOR(i,lmt/2+1,lmt-1)w[i]=1ll*w[i-1]*wn%mod;
    ROF(i,lmt/2-1,1)w[i]=w[i<<1];
}
V DFT(int*a,int l){
    static unsigned long long int tmp[N];
    int t,u=__builtin_ctz(lmt/l);
    FOR(i,0,l-1)tmp[i]=a[r[i]>>u];
    for(int i=1;i<l;i<<=1)for(int j=0,d=i<<1;j<l;j+=d)FOR(k,0,i-1)
        t=tmp[i|j|k]*w[i|k]%mod,tmp[i|j|k]=tmp[j|k]+mod-t,tmp[j|k]+=t;
    FOR(i,0,l-1)a[i]=tmp[i]%mod;
}
V IDFT(int*a,int l){reverse(a+1,a+l),DFT(a,l),mul(a,mod-mod/l,l,a);}
V Inv(const int*a,int n,int*b){
    if(n==1)return void(b[0]=Pow(a[0],mod-2));
    static int tmp[N],l;
    Inv(a,n+1>>1,b),copy(a,a+n,tmp),DFT(tmp,l=getLen(n<<1)),DFT(b,l);
    FOR(i,0,l-1)b[i]=(2ll+mod-1ll*tmp[i]*b[i]%mod)*b[i]%mod;
    IDFT(b,l),cl(b+n,b+l),cl(tmp,tmp+l);
}
V conv(int*a,int*b,int l,int*c){DFT(a,l),DFT(b,l),dot(a,b,l,c),IDFT(c,l);}
V deri(const int*a,int n,int*b){FOR(i,1,n-1)b[i-1]=1ll*i*a[i]%mod;b[n-1]=0;}
V inte(const int*a,int n,int*b){FOR(i,1,n-1)b[i]=1ll*inv[i]*a[i-1]%mod;b[0]=0;}
V Ln(const int*a,int n,int*b){
    static int tmp[N],l;
    l=getLen(n<<1),Inv(a,n,tmp),deri(a,n,b);
    conv(tmp,b,l,tmp),inte(tmp,n,b),cl(b+n,b+l),cl(tmp,tmp+l);
}
V Exp(const int*a,int n,int*b){
    if(n==1)return void(b[0]=1);
    static int tmp[N],l;
    Exp(a,n+1>>1,b),Ln(b,n,tmp),tmp[0]--;
    FOR(i,0,n-1)check(tmp[i]=a[i]+mod-tmp[i]);
    conv(b,tmp,l=getLen(n<<1),b),cl(b+n,b+l),cl(tmp,tmp+l);
}
V Mul(const int*a,const int*b,int n,int m,int*c){
    static int X[N],Y[N],l;
    l=getLen(n+m-1),copy(a,a+n,X),copy(b,b+m,Y);
    conv(X,Y,l,X),copy(X,X+n+m-1,c),cl(X,X+l),cl(Y,Y+l);
}
V Pow(const int*a,int n,int k,int*b){
    static int tmp[N];
    if(n>1)Ln(a,n,tmp),mul(tmp,k,n,tmp),Exp(tmp,n,b),cl(tmp,tmp+n);
}
V solve(int n){
    if(n==1)return void(f[0]=0);
    static int tmp[N],up[N],dw[N],l;
    solve(n+1>>1),l=getLen(n<<1),Pow(f+1,n-k,k,tmp+k),mul(tmp,mod-1,n,tmp);
    FOR(i,0,n-1)dw[i]=1ll*(k+1)*tmp[i]%mod;
    dw[0]++,dw[1]++,tmp[0]++,tmp[1]++,Mul(f,tmp,n,n,up),check(up[1]+=mod-1);
    cl(tmp,tmp+l),Inv(dw,n,tmp),cl(up+n,up+n+n-1),conv(tmp,up,l,tmp);
    FOR(i,0,n-1)check(f[i]+=mod-tmp[i]);
    cl(tmp,tmp+l),cl(up,up+l),cl(dw,dw+l);
}
V calc(int*g,int n){
    Pow(f+1,n+1,mod-n-1,g);
    FOR(i,0,n)g[i]=1ll*(n-i+1)*g[i]%mod*inv[n+1]%mod;
}
int main(){
    scanf("%d%d",&m,&k);
    FOR(i,1,m)scanf("%d",a+i);
    sort(a+1,a+m+1),init(m+2<<1);
    for(int p=1,d=0;p<=m;p+=d,len[++tot]=d+1)for(d=0;a[p+d]==a[p]+d;d++);
    sort(len+1,len+1+tot),solve(len[tot]+1);
    FOR(i,1,tot)
        if(t[i]=np,np+=len[i],len[i]==len[i-1])copy(t[i-1],t[i-1]+len[i],t[i]);
        else calc(t[i],len[i]-1);
    for(int x=1,y;y=x+1,x<tot;len[tot]=len[x]+len[y]-1,np+=len[tot],x+=2)
        t[++tot]=np,Mul(t[x],t[y],len[x],len[y],t[tot]);
    FOR(i,0,m)check(ans+=1ll*t[tot][i]*Pow(1ll*C(m,i)*(m-i)%mod,mod-2,m)%mod);
    return cout<<ans,0;
}