题解:P11420 [清华集训 2024] 乘积的期望

· · 题解

Trick:对于定长度区间操作经常可以转化为网格图平衡复杂度,印象中 QOJ4893. Imbalance 也是的。

考虑将 n 个元素划分成如下的网格图(n 不是 m 的倍数的时候就补 b_i=0 凑成一个完整的图)。

可以发现每次选择 m 个连续的位置一定是某一行的后缀搭配上下一行的前缀,如图中蓝色部分。

有两种转移方式,一种是从上往下行间转移,另一种是从左往右在列上面进行转移。如果我们能找到两种转移方式就能平衡复杂度。

首先,将所有元素乘积的期望转化成组合意义,也就是对于所有覆盖方案我们都对于每个位置选择一种覆盖其的操作的总方案数。我们可以对着这个进行 DP。

竖向 DP

思考一下 DP 需要记录哪些状态,首先有记录当前的位置 i,其次还有当前有几条操作线段延伸过来了,可以我们发现这还不够,无法刻画长度为 m 这个约束。于是把这个改为一个 m 位的二进制数 s,记录该位置往前的 m 个位置中有哪些作为线段的开头,这还不行,因为可能有多个线段共用一个开头,其实这个约束可以改为选这个线段的第一个元素。所以就变成了,某个位置是否是一条线段中被选择的第一个位置,这样子在一个位置上只会和一条线段有关,就可以用 01 变量记录了,这样子我们只需要约束选这个线段的第一个元素和最后一个元素距离 \le k 即可。

这样子我们就可以根据信息开闭线段了。同时,我们发现上述过程并没有对于所有线段作出区分,于是记录一下我们已经选了的线段数量 j,这样子最后为了让这些线段可区分,我们就乘以 A^j_C\times (\sum b_i)^{m-j} 的系数。最后除以 (\sum b_i)^m 即可。

设计好了状态,现在需要进行 dp。考虑当前这一个位置 i,我们要为其选择一个覆盖它的线段(不考虑无线段覆盖它的情况,因为这种情况的贡献为 0)。

首先,如果 i-m+1 这个位置被选择了开头第一个元素,但是没有选择收尾元素,我们就必须选择这条线段,因为这条线段最远也就延伸到 i,而我们加入的时候已经钦定过其有结尾元素了,所以必须要为它找一个结尾元素,那么就只能选择 i 了。

否则,我们可以自由选择。第一个选择就是在 i 这个位置新开一条线段(不一定以 i 作为左端点,但是 i 一定是其第一个元素),注意这个线段有两种选择,要么在这里终结,也就是说开头结尾都是 i;要么往后延伸,也就说我们还要为其找一个 \neq i 的结尾。

如果不选择新开线段,我们可以为其挑选一个之前开的线段,钦定其覆盖 i。这里还有两种选择,一种是这条线段在 i 终结,也就是说 i 是其最后一个元素,注意由于我们加入线段的时候没有为其选择左端点 l,所以在终结的时候要选择一个,假设其第一个位置是 j,那么左端点的范围就是 [i-m+1,j],求一下区间 \sum b_i 作为系数即可。如果这条线段不在这里终结,那很简单,正常转移就行了。

分析一下时间复杂度,状态数 O(nc2^m)。注意到最多有 O(n) 条线段被选择,所以状态数可以被优化到 O(n^22^m),转移是 O(m) 的,所以是 O(n^2m2^m)。当 2m>n 的时候,中间的 2m-n 个元素会被所有操作都覆盖一次,所以这部分贡献可以直接算的,是 c^{2m-n},然后把这些部分去掉,就可以让 2m\le n 了。所以指数可以对于 \dfrac{n}{2}\min,最后的复杂度就是 O(n^2m2^{\min(m,\frac{n}{2})})。分析一下分数表格,发现在指数为 16 左右的时候是可以过的,因此这个做法是可以做 m\le 16 或者 n\le30 的测试点。这样子可以得到 60\rm pts

横向 DP

上述做法可以通过 m\le 16 或者 n\le 30 的部分分,所以剩下的数据范围就是 m\ge 1731\le n\le 50,可以发现这个时候本题解配图中的 k=\lceil \dfrac{n}{m}\rceil=3,也就是说网格图分为三列。而我们横着做 DP,就是每次的状态是一列。

尝试观察一些刻画方式,可以发现每次操作必然是对于每一列的三个位置中恰好有一个位置被操作。所以 a_i+a_{i+m}+a_{i+2m}=m

还有就是基本上所有线段都是跨越两行的,这会导致交界处被覆盖次数很多。所以第一行中的 a_i 单调递增,第三行中的 a_i 单调递减。

这两个条件足够了吗?并不是的。还有一个约束就是你会发现,要么一次操作是行内的,要么一二联动,要么二三联动,唯独没有一三联动的情况,所以 a_m+a_{2m+1}\le C

可以证明这三个条件已经是充要条件了。可以对于这个结构按照列进行 DP。

对于第三个约束,我们直接外层枚举 a_m 就可以解决了。对于第一个约束告诉了我们,我们只需要记录三个位置中的其中两个位置的值,然后第三个位置的值就可以通过做差直接得到了。那么是记录哪两个位置呢?为了方便限制第二个约束,我们记录第一行和第三行的就就行了。设 f_{i,j,k} 表示 dp 到了第 i 列,其中 a_i=j,a_{i+2m}=k 的方案数。注意由于是区分顺序的,所以外面有一个 C!,但是如果遇到有 i 个相同起点的操作,要除以 \dfrac{1}{i!}

转移的时候,枚举 j'\ge jk'\le k 进行转移即可。直接做的时间复杂度是 O(nC^5)

有一个小 trick 就是,你发现 j,k 之间是相对独立的,所以你可以分开转移 j,k,也就是说 (j,k)\to (j',k)\to (j',k'),这样子就优化到了 O(nC^4)。可以拿到 90 \rm pts

发现瓶颈在于 C 很大,有一个观察就是你在竖向 DP 的时候,会发现整个 DP 过程都是和 C 无关的,只有最后乘以组合数的时候才涉及 C,最后是一个关于 C1n 次的多项式求和,所以最终的答案关于 C 是一个 n 次多项式。于是我们算出 C\in [1,n+1] 的答案,然后进行拉格朗日插值即可。

这部分的时间复杂度是 O(n^6)

两个做法综合起来可以通过本题。

#include<bits/stdc++.h>
#define pb emplace_back
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
const int maxm=17;
const int maxn=60;
const int mod=998244353;
void add(int &x,int y){ x=x+y>=mod?x+y-mod:x+y; }
void sub(int &x,int y){ x=x<y?x-y+mod:x-y; }
void cmax(int &x,int y){ x=x>y?x:y; }
void cmin(int &x,int y){ x=x<y?x:y; }
int b[maxn],pre[maxn],fac[maxn],h[maxn],y[maxn],n,m,C,F=1,ans=0;
int dp[2][maxn][(1<<maxm)],f[maxn][maxn][maxn],pw[maxn<<2][maxn];
int qpow(int x,int k){
    int res=1;
    for(;k;k>>=1){
        if(k&1) res=1ll*res*x%mod;
        x=1ll*x*x%mod;
    }
    return res;
}
int sum(int l,int r){ return pre[r]-pre[max(0,l-1)]; }
int val(int x,int v){
    if(x<=n) return v;
    return 1;
}
int Lagrange(int x0){
    for(int i=1;i<=n+1;i++){
        int up=1,down=1;
        for(int j=1;j<=n+1;j++){
            if(i==j) continue;
            up=1ll*up*(x0-j+mod)%mod;
            down=1ll*down*(i-j+mod)%mod;
        }
        add(ans,1ll*y[i]*up%mod*qpow(down,mod-2)%mod);
    }
    return ans;
}
int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin>>n>>m>>C; h[0]=1; fac[0]=1;
    if(1ll*C*m<n){ cout<<"0"; return 0; }
    for(int i=1;i<=n-m+1;i++) cin>>b[i];
    for(int i=1;i<=n;i++){
        pre[i]=(pre[i-1]+b[i])%mod;
        h[i]=1ll*h[i-1]*(C-i+1)%mod;
        fac[i]=1ll*fac[i-1]*i%mod;
    } fac[n+1]=1ll*fac[n]*(n+1)%mod; 
    if(2*m>n){
        F=qpow(C,2*m-n); int t=2*m-n;
        m-=t; n-=t;
    }
    if(m==1){
        F=1ll*F*qpow(qpow(pre[n],mod-2),C)%mod;
        ans=1ll*ans*qpow(pre[n],C-n)%mod*h[n]%mod;
        for(int i=1;i<=n;i++) ans=1ll*ans*b[i]%mod;
        cout<<1ll*ans*F%mod; return 0; 
    }
    if(m<=16||n<=30){
        F=1ll*F*qpow(qpow(pre[n],mod-2),C)%mod;
        dp[0][0][0]=1; int p=0,q=1,lim=1<<m-1;
        for(int i=1;i<=n;i++,p^=1,q^=1){
            memset(dp[q],0,sizeof(dp[q]));
            for(int j=0;j<=i-1;j++){
                for(int s=0;s<(1<<m-1);s++){
                    if(!dp[p][j][s]) continue;
                    if(s>>m-2&1){ add(dp[q][j][(s<<1)%lim],1ll*dp[p][j][s]*sum(i-m+1,i-m+1)%mod); continue; }
                    for(int k=0;k<=m-2;k++){
                        if(!(s>>k&1)) continue;
                        add(dp[q][j][s<<1],dp[p][j][s]);
                        add(dp[q][j][(s-(1<<k))<<1],1ll*dp[p][j][s]*sum(i-m+1,i-k-1)%mod);
                    }
                    add(dp[q][j+1][s<<1],1ll*dp[p][j][s]*sum(i-m+1,i)%mod);
                    add(dp[q][j+1][s<<1|1],dp[p][j][s]);
                }
            }
        }
        for(int i=1;i<=min(C,n);i++) add(ans,1ll*dp[p][i][0]*h[i]%mod*qpow(pre[n],C-i)%mod);
        cout<<1ll*ans*F%mod; return 0;
    }
    for(int i=1;i<=n+1;i++){
        pw[i][0]=1;
        for(int j=1;j<=n+2;j++) pw[i][j]=1ll*pw[i][j-1]*b[i]%mod*qpow(j,mod-2)%mod;
    }
    for(int c=1;c<=n+1;c++){//插值的点 
        for(int A=0;A<=c;A++){//c_{2m+1}
            int p=0,q=1; memset(f[p],0,sizeof(f[p]));
            for(int i=0;i+A<=c;i++){//预处理第一列 
                f[p][i][A]=1ll*val(1,i)*val(m+1,c-i-A)%mod*val(2*m+1,A)%mod*pw[1][i]%mod;
            }
            for(int i=2;i<=m;i++){
                memset(f[q],0,sizeof(f[q]));
                for(int j=0;j<=c-A;j++)
                    for(int k=0;k<=A;k++)
                        for(int j2=j;j2<=c-A;j2++)
                            add(f[q][j2][k],1ll*f[p][j][k]*pw[i][j2-j]%mod);
                memset(f[p],0,sizeof(f[p]));
                for(int j=0;j<=c-A;j++)
                    for(int k=0;k<=A;k++)
                        for(int k2=0;k2<=k;k2++)
                            add(f[p][j][k2],1ll*f[q][j][k]*pw[m+i][k-k2]%mod*val(i,j)%mod*val(m+i,c-j-k2)%mod*val(2*m+i,k2)%mod);
            }
            for(int j=0;j<=c-A;j++)
                for(int k=0;k<=A;k++)
                    add(y[c],1ll*f[p][j][k]*pw[m+1][c-j-A]%mod*pw[2*m+1][k]%mod); 
        }
        y[c]=1ll*y[c]*fac[c]%mod; y[c]=1ll*y[c]*qpow(qpow(pre[n],mod-2),c)%mod;
    }
    cout<<1ll*F*Lagrange(C)%mod;
    return 0;
}