题解:P11420 [清华集训 2024] 乘积的期望
Trick:对于定长度区间操作经常可以转化为网格图平衡复杂度,印象中 QOJ4893. Imbalance 也是的。
考虑将
可以发现每次选择
有两种转移方式,一种是从上往下行间转移,另一种是从左往右在列上面进行转移。如果我们能找到两种转移方式就能平衡复杂度。
首先,将所有元素乘积的期望转化成组合意义,也就是对于所有覆盖方案我们都对于每个位置选择一种覆盖其的操作的总方案数。我们可以对着这个进行 DP。
竖向 DP
思考一下 DP 需要记录哪些状态,首先有记录当前的位置
这样子我们就可以根据信息开闭线段了。同时,我们发现上述过程并没有对于所有线段作出区分,于是记录一下我们已经选了的线段数量
设计好了状态,现在需要进行 dp。考虑当前这一个位置
首先,如果
否则,我们可以自由选择。第一个选择就是在
如果不选择新开线段,我们可以为其挑选一个之前开的线段,钦定其覆盖
分析一下时间复杂度,状态数
横向 DP
上述做法可以通过
尝试观察一些刻画方式,可以发现每次操作必然是对于每一列的三个位置中恰好有一个位置被操作。所以
还有就是基本上所有线段都是跨越两行的,这会导致交界处被覆盖次数很多。所以第一行中的
这两个条件足够了吗?并不是的。还有一个约束就是你会发现,要么一次操作是行内的,要么一二联动,要么二三联动,唯独没有一三联动的情况,所以
可以证明这三个条件已经是充要条件了。可以对于这个结构按照列进行 DP。
对于第三个约束,我们直接外层枚举
转移的时候,枚举
有一个小 trick 就是,你发现
发现瓶颈在于
这部分的时间复杂度是
两个做法综合起来可以通过本题。
#include<bits/stdc++.h>
#define pb emplace_back
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
const int maxm=17;
const int maxn=60;
const int mod=998244353;
void add(int &x,int y){ x=x+y>=mod?x+y-mod:x+y; }
void sub(int &x,int y){ x=x<y?x-y+mod:x-y; }
void cmax(int &x,int y){ x=x>y?x:y; }
void cmin(int &x,int y){ x=x<y?x:y; }
int b[maxn],pre[maxn],fac[maxn],h[maxn],y[maxn],n,m,C,F=1,ans=0;
int dp[2][maxn][(1<<maxm)],f[maxn][maxn][maxn],pw[maxn<<2][maxn];
int qpow(int x,int k){
int res=1;
for(;k;k>>=1){
if(k&1) res=1ll*res*x%mod;
x=1ll*x*x%mod;
}
return res;
}
int sum(int l,int r){ return pre[r]-pre[max(0,l-1)]; }
int val(int x,int v){
if(x<=n) return v;
return 1;
}
int Lagrange(int x0){
for(int i=1;i<=n+1;i++){
int up=1,down=1;
for(int j=1;j<=n+1;j++){
if(i==j) continue;
up=1ll*up*(x0-j+mod)%mod;
down=1ll*down*(i-j+mod)%mod;
}
add(ans,1ll*y[i]*up%mod*qpow(down,mod-2)%mod);
}
return ans;
}
int main(){
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin>>n>>m>>C; h[0]=1; fac[0]=1;
if(1ll*C*m<n){ cout<<"0"; return 0; }
for(int i=1;i<=n-m+1;i++) cin>>b[i];
for(int i=1;i<=n;i++){
pre[i]=(pre[i-1]+b[i])%mod;
h[i]=1ll*h[i-1]*(C-i+1)%mod;
fac[i]=1ll*fac[i-1]*i%mod;
} fac[n+1]=1ll*fac[n]*(n+1)%mod;
if(2*m>n){
F=qpow(C,2*m-n); int t=2*m-n;
m-=t; n-=t;
}
if(m==1){
F=1ll*F*qpow(qpow(pre[n],mod-2),C)%mod;
ans=1ll*ans*qpow(pre[n],C-n)%mod*h[n]%mod;
for(int i=1;i<=n;i++) ans=1ll*ans*b[i]%mod;
cout<<1ll*ans*F%mod; return 0;
}
if(m<=16||n<=30){
F=1ll*F*qpow(qpow(pre[n],mod-2),C)%mod;
dp[0][0][0]=1; int p=0,q=1,lim=1<<m-1;
for(int i=1;i<=n;i++,p^=1,q^=1){
memset(dp[q],0,sizeof(dp[q]));
for(int j=0;j<=i-1;j++){
for(int s=0;s<(1<<m-1);s++){
if(!dp[p][j][s]) continue;
if(s>>m-2&1){ add(dp[q][j][(s<<1)%lim],1ll*dp[p][j][s]*sum(i-m+1,i-m+1)%mod); continue; }
for(int k=0;k<=m-2;k++){
if(!(s>>k&1)) continue;
add(dp[q][j][s<<1],dp[p][j][s]);
add(dp[q][j][(s-(1<<k))<<1],1ll*dp[p][j][s]*sum(i-m+1,i-k-1)%mod);
}
add(dp[q][j+1][s<<1],1ll*dp[p][j][s]*sum(i-m+1,i)%mod);
add(dp[q][j+1][s<<1|1],dp[p][j][s]);
}
}
}
for(int i=1;i<=min(C,n);i++) add(ans,1ll*dp[p][i][0]*h[i]%mod*qpow(pre[n],C-i)%mod);
cout<<1ll*ans*F%mod; return 0;
}
for(int i=1;i<=n+1;i++){
pw[i][0]=1;
for(int j=1;j<=n+2;j++) pw[i][j]=1ll*pw[i][j-1]*b[i]%mod*qpow(j,mod-2)%mod;
}
for(int c=1;c<=n+1;c++){//插值的点
for(int A=0;A<=c;A++){//c_{2m+1}
int p=0,q=1; memset(f[p],0,sizeof(f[p]));
for(int i=0;i+A<=c;i++){//预处理第一列
f[p][i][A]=1ll*val(1,i)*val(m+1,c-i-A)%mod*val(2*m+1,A)%mod*pw[1][i]%mod;
}
for(int i=2;i<=m;i++){
memset(f[q],0,sizeof(f[q]));
for(int j=0;j<=c-A;j++)
for(int k=0;k<=A;k++)
for(int j2=j;j2<=c-A;j2++)
add(f[q][j2][k],1ll*f[p][j][k]*pw[i][j2-j]%mod);
memset(f[p],0,sizeof(f[p]));
for(int j=0;j<=c-A;j++)
for(int k=0;k<=A;k++)
for(int k2=0;k2<=k;k2++)
add(f[p][j][k2],1ll*f[q][j][k]*pw[m+i][k-k2]%mod*val(i,j)%mod*val(m+i,c-j-k2)%mod*val(2*m+i,k2)%mod);
}
for(int j=0;j<=c-A;j++)
for(int k=0;k<=A;k++)
add(y[c],1ll*f[p][j][k]*pw[m+1][c-j-A]%mod*pw[2*m+1][k]%mod);
}
y[c]=1ll*y[c]*fac[c]%mod; y[c]=1ll*y[c]*qpow(qpow(pre[n],mod-2),c)%mod;
}
cout<<1ll*F*Lagrange(C)%mod;
return 0;
}