题解:P11420 [清华集训 2024] 乘积的期望

· · 题解

首先考虑一下 m\leq 16 的做法。

考虑所有元素乘积的组合意义,即为每个位置选择一个覆盖它的操作的方案数。转化一下,就是对每次操作,选择它覆盖的区间内的一个下标子集,满足该子集内的下标都未被覆盖,然后将这些下标覆盖。

那么可以发现子集内下标数量非零的只有 O(n) 次操作(称为关键操作),且对于一个子集 x_1,x_2\dots x_k 来说,可以覆盖这些下标的区间左端点 l\in [x_k-m+1,x_1] 只和 x_1,x_k 有关。那么我们可以从左到右 dp,设 f_{i,j,S} 表示考虑前 i 个位置,当前已经划分了 j 次关键操作,S 是一个 m-1 位的二进制数,表示 [i-m,i-1] 这些位置里哪些位置是作为某个子集的 x_1x_k 仍未选择,转移就是分成如下几类讨论:

最后乘上 \dbinom{C}{j}j!(\sum b)^{C-j} 即可,因为求的是期望最后还要 /(\sum b)^C,复杂度 O(n^2m2^m)

该做法可以说明 C 只在组合数里出现,因此答案是一个关于 Cn 次多项式。

对于 m>\frac{n}{2} 的情况,可以发现中间长度为 2m-n 的部分会被每次操作都覆盖,因此这些位置最后都是 C,删掉即可,结合 m\leq 16 可以得到一个 O(2^{\frac{n}{2}}poly(n)) 的做法,可以通过 n\leq 30 的部分。

接下来考虑一个 m>\frac{n}{3} 时的做法,方便起见把 n 补成 3m

对于最终得到的序列 a,我们将其分成三部分 a_{1}\sim a_{m},a_{m+1}\sim a_{2m},a_{2m+1}\sim a_{3m}

那么我们可以得到如下性质,下面的 i\in [1,m]

可以说明这三个条件就是 a 序列的合法的充要条件:

那么我们可以枚举 a_{2m+1},然后设 f_{i,j,k} 表示考虑了 a_{1}\sim a_{i},a_{m+1}\sim a_{m+i},a_{2m+1}\sim a_{2m+i},当前的 a_i=j,a_{2m+i}=k 的方案数,转移就枚举 j'\geq j,l'\leq k 转移到 f_{i+1,j',k'},复杂度 O(nC^5),注意到可以先转移 j\to j' 再转移 k\to k' 这样就变成 O(nC^4)

根据之前所说,答案是一个关于 Cn 次多项式,那么我们求出前 n+1 项之后插值出原答案即可。

复杂度 O(n^6)

#include<bits/stdc++.h>
using namespace std;
const int N = 180;
int n,m,C;
template <typename T>inline void read(T &x)
{
    x=0;char c=getchar();bool f=0;
    for(;c<'0'||c>'9';c=getchar())f|=(c=='-');
    for(;c>='0'&&c<='9';c=getchar())x=(x<<1)+(x<<3)+(c-'0');
    x=(f?-x:x);
}
const int mod = 998244353;
int b[N],s[N];
int dp[2][60][(1<<19)];
int binom[N];
void add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
int Pow(int a,int b)
{
    if(b<0)return 0;
    int res=1;
    while(b)
    {
        if(b&1)res=1ll*res*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return res;
}
int calc(int l,int r)
{
    swap(l,r);
    l=l-m+1;
    l=max(l,1);
    return s[r]-s[l-1];
}
int B[N];
int f[N][N][N],ifac[N],fac[N],pw[N][N];
void init(int k)
{
    ifac[0]=fac[0]=1;
    for(int i=1;i<=k;i++)
    {
        fac[i]=1ll*fac[i-1]*i%mod;
        ifac[i]=Pow(fac[i],mod-2);
    }
    for(int i=1;i<=n;i++)
    {
        pw[i][0]=1;
        for(int j=1;j<=k;j++)
        pw[i][j]=1ll*pw[i][j-1]*b[i]%mod;
    }
}
int X[N],Y[N],g[N][N];
int Lag(int n)
{
    int ans=0;
    for(int i=1;i<=n;i++)
    {
        int v=1;
        for(int j=1;j<=n;j++)
        if(i!=j)v=1ll*v*(C-j)%mod*Pow(i-j,mod-2)%mod;
        ans=(ans+1ll*v*Y[i]%mod)%mod;
    }
    ans=(ans%mod+mod)%mod;
    return ans;
}
int lg[(1<<19)];
int main()
{
    read(n);read(m);read(C);
    for(int i=1;i<=n-m+1;i++)read(b[i]);
    if(m==n)
    {
        cout<<Pow(C,n);
        return 0;
    }
    binom[0]=1;
    for(int i=1;i<=n;i++)
    binom[i]=1ll*binom[i-1]*(C-i+1)%mod;
    int alt=1;
    if(2*m>n)
    {
        int l=2*m-n;
        alt=Pow(C,l);
        m-=l;n-=l;
    }
    for(int i=1;i<=n;i++)s[i]=s[i-1]+b[i];
    if(m==1)
    {
        int ans=1ll*binom[n]*Pow(s[n],C-n)%mod;
        for(int i=1;i<=n;i++)ans=1ll*ans*b[i]%mod;
        ans=1ll*ans*Pow(Pow(s[n],mod-2),C)%mod;
        cout<<1ll*ans*alt%mod;
        return 0;
    }
    if(m<=16)
    {       
        dp[0][0][0]=1;
        int cur=0;
        for(int i=0;i<=m-1;i++)lg[1<<i]=i;
        for(int i=1;i<=n;i++)
        {
            int nxt=(cur^1);
            for(int j=0;j<=i-1;j++)
            for(int s=0;s<(1<<(m-1));s++)
            if(dp[cur][j][s])
            {
                if((s>>(m-2))&1)add(dp[nxt][j][(s-(1<<(m-2)))<<1],1ll*dp[cur][j][s]*calc(i-m+1,i)%mod);
                else 
                {
                    for(int t=s;t;t-=(t&-t))
                    {
                        int k=lg[t&-t]+1;
                        add(dp[nxt][j][s<<1],dp[cur][j][s]);
                        add(dp[nxt][j][(s-(1<<(k-1)))<<1],1ll*dp[cur][j][s]*calc(i-k,i)%mod);
                    }
                    add(dp[nxt][j+1][s<<1|1],dp[cur][j][s]);
                    add(dp[nxt][j+1][s<<1],1ll*dp[cur][j][s]*calc(i,i)%mod);
                }
                dp[cur][j][s]=0;
            }
            cur=nxt;
        }
        int ans=0;
        for(int i=1;i<=n;i++)
        add(ans,1ll*dp[cur][i][0]*binom[i]%mod*Pow(s[n],C-i)%mod);
        ans=1ll*ans*Pow(Pow(s[n],mod-2),C)%mod;
        cout<<1ll*ans*alt%mod;
        return 0;
    }
    int L=n;
    n=3*m;
    for(int i=1;i<=n;i++)
    {
        s[i]=s[i-1]+b[i];
    }
    init(n*2);
    for(int c=1;c<=L+1;c++)
    {
        int ret=0;
        for(int w=0;w<=c;w++)
        {
            for(int i=1;i<=m;i++)
            for(int j=0;j<=c;j++)
            for(int k=0;k<=c;k++)
            f[i][j][k]=0;
            auto val = [&](int i,int v)
            {
                if(i<=L)return v;
                return 1;
            };
            for(int j=0;j<=c-w;j++)
            f[1][j][w]=1ll*ifac[j]*pw[1][j]%mod*val(1,j)%mod*val(2*m+1,w)%mod*val(m+1,c-w-j)%mod;
            for(int i=2;i<=m;i++)
            {           
                for(int j=0;j<=c;j++)
                for(int k=0;k<=c-j;k++)
                if(f[i-1][j][k])
                {
                    for(int nj=j;nj<=c-w;nj++)
                    {
                        int v=f[i-1][j][k];
                        v=1ll*v*ifac[nj-j]%mod;
                        v=1ll*v*pw[i][nj-j]%mod;
                        add(g[nj][k],v);
                    }
                }
                for(int nj=0;nj<=c-w;nj++)
                for(int k=0;k<=c;k++)
                if(g[nj][k])
                {
                    for(int nk=0;nk<=min(k,c-nj);nk++)
                    {
                        int v=g[nj][k];
                        v=1ll*v*ifac[k-nk]%mod;
                        v=1ll*v*pw[m+i][k-nk]%mod;
                        v=1ll*v*val(i,nj)%mod*val(m+i,c-nj-nk)%mod*val(2*m+i,nk)%mod;
                        add(f[i][nj][nk],v);
                    }
                    g[nj][k]=0;
                }
            }
            for(int j=0;j<=c-w;j++)
            for(int k=0;k<=w;k++)
            if(f[m][j][k])
            add(ret,1ll*f[m][j][k]*ifac[k]%mod*pw[2*m+1][k]%mod*ifac[c-j-w]%mod*pw[m+1][c-j-w]%mod);
        }
        ret=1ll*ret*fac[c]%mod;
        ret=1ll*ret*Pow(Pow(s[n],c),mod-2)%mod;
        X[c]=c;
        Y[c]=ret;
    }
    int ans=Lag(L+1);
    cout<<1ll*ans*alt%mod;
    return 0;
}
/* 
4 2 5000
1 1 1
*/