devinwang is watching you
提供一个
显然 dp,用三维分别表示当前位置,当前连续段数,当前是否为
代码中分别叫 dp 和 qwq。
对于一个状态,我们发现它最多由两个先前状态推出来。
我们算出这个状态的概率以及由每个状态转移过来的概率。
这样就可以算了。
具体的 dp 式子见代码。
代码在下面。由于中间的 Mod 写的很丑,就扔 这里.
#include <stdio.h>
inline int read()
{
int num=0;char c=getchar();
while(c<48||c>57)c=getchar();
while(c>47&&c<58)num=(num<<3)+(num<<1)+(c^48),c=getchar();
return num;
}
#define int long long
const int mod = 998244353;
int qp(int x,int p)
{
int res=1;
while(p)
{
if(p&1)res=res*x%mod;
x=x*x%mod;
p>>=1;
}
return res;
}
struct Mod{...};
Mod a[100005],p[100005],dp[100005][20][2],qwq[100005][20][2];
signed main()
{
int n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++)p[i]=read();
//memset(dp,-0x3f,sizeof(dp));
qwq[0][0][0]=1;
Mod unit=1;
for(int i=1;i<=n;i++)
{
for(int k=0;k<=m;k++)
{
if(k==0)
{
qwq[i][k][0]=qwq[i-1][k][0]*(mod+1-p[i].x);
goto hidsa;
}
if(k+k-1>i)continue;
else if(k+k-1==i)
{
qwq[i][k][1]=qwq[i-1][k-1][0]*p[i];
//dp[i][k][1]=qwq[i][k][1]*a[i]+dp[i-1][k-1][0];
dp[i][k][1]=a[i]+dp[i-1][k-1][0];
}
else
{
qwq[i][k][1]=qwq[i-1][k-1][0]*p[i]+qwq[i-1][k][1]*p[i];
Mod wqw=qwq[i-1][k-1][0]+qwq[i-1][k][1];
wqw=unit/wqw;
dp[i][k][1]=a[i];//dp[i-1][k-1][0]*qwq[i-1][k-1][0]
dp[i][k][1]+=dp[i-1][k-1][0]*qwq[i-1][k-1][0]*wqw;
dp[i][k][1]+=dp[i-1][k][1]*qwq[i-1][k][1]*wqw;
//dp[i][k][1]=qwq[i][k][1]*a[i]+dp[i-1][k-1][0]+dp[i-1][k][1];
qwq[i][k][0]=(qwq[i-1][k][0]+qwq[i-1][k][1])*(mod+1-p[i].x);
//dp[i][k][0]=qwq[i][k][0]*(dp[i-1][k][0]+dp[i-1][k][1]);
wqw=qwq[i-1][k][0]+qwq[i-1][k][1];
wqw=unit/wqw;
dp[i][k][0]=dp[i-1][k][0]*qwq[i-1][k][0]*wqw;
dp[i][k][0]+=dp[i-1][k][1]*qwq[i-1][k][1]*wqw;
}
/* if(k*2-1>i)continue;
if(k*2-1!=i){
dp[i][k][0]+=dp[i-1][k][1]*(mod-p[i].x+1);
dp[i][k][0]+=dp[i-1][k][0]*(mod-p[i].x+1);}
//dp[i][k][1]=p[i]*a[i];
dp[i][k][1]+=(a[i]+dp[i-1][k-1][0]+dp[i-1][k][1])*p[i];*/
hidsa:;/*
printf("dp[ %lld %lld %lld] = %lld\n",i,k,0ll,dp[i][k][0].x);
printf("dp[ %lld %lld %lld] = %lld\n",i,k,1ll,dp[i][k][1].x);
printf("qwq[ %lld %lld %lld] = %lld\n",i,k,0ll,qwq[i][k][0].x);
printf("qwq[ %lld %lld %lld] = %lld\n",i,k,1ll,qwq[i][k][1].x);*/
}
}
//printf("%lld\n",((dp[n][m][0]+dp[n][m][1])/qp(2,n)*(qp(2,n)-1)).x);
printf("%lld\n",(dp[n][m][0]*qwq[n][m][0]+dp[n][m][1]*qwq[n][m][1]).x);
}
/*
2 1
3 5
0 1
*/