题解 P4457 【[BJOI2018]治疗之雨】

· · 题解

贡献一个不用高斯消元来做(但实际上差不多只不过解决问题的方向不一样)的题解吧。

首先k次中有i次正好对自己造成了伤害的概率是\frac{C_k^im^{k-i}}{(m+1)^k}

那么依靠这个我们就可以非常简单的预处理一个p[i][j]表示当前血量为i,经过了一轮之后血量变成了j的概率。

然后就考虑设f[i]为血量为i时扣血扣成0的期望步数。

那么转移就非常显然:

f[0]=0 f[1]=p[1][0]\times f[0]+p[1][1]\times f[1]+p[1][2]\times f[2]+1 f[2]=p[2][0]\times f[0]+p[2][1]\times f[1]+p[2][2]\times f[2]+p[2][3]\times f[3]+1
               ........
f[i]=1+\sum_{j=0}^{i+1}p[i][j]\times f[j]

那么显然这个可以高斯消元,但是复杂度为O(n^3),爆炸。

考虑看看这个转移式有什么性质。

可以发现,第i个方程的变量是1i+1(第n个除外),那么可以考虑在第i个等式将f[i+1]表示成a\times f[1]+b这样的形式(i\in [1,n-1]),然后再将它们代入第n个等式中解出f[1],然后再代入求出f[p]即可。

这样的复杂度是O(n^2)的,总复杂度就是O(Tn^2)

需要特判m=0k=0的情况。

code:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<iostream>
#include<set>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#define G 3
#define eps 1e-15
#define maxn 1505
#define maxm 100010
#define inf 999999999999999
#define mod 1000000007
#define mp(x,y) make_pair(x,y)
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef unsigned int uint;
typedef pair<int,int> pii;
int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch-'0'<0||ch-'0'>9){if(ch=='-') f=-1;ch=getchar();}
    while(ch-'0'>=0&&ch-'0'<=9){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int T;
int n,px,m,K,INV,INVX;
int C[maxn],p[maxn][maxn],ci[maxn],cix[maxn];
int quick_pow(int x,int p)
{
    int an=1,po=x;
    while(p)
    {
        if(p&1)  an=1ll*an*po%mod;
        po=1ll*po*po%mod;
        p>>=1;
    }
    return an;
}
int calc(int i)
{
    if(i<0||i>K)  return 0;
    return 1ll*C[i]*ci[i]%mod*INVX%mod;
}
struct P{
    int k,b;
}a[maxn];
P operator * (P a,int b)
{
    return (P){1ll*a.k*b%mod,1ll*a.b*b%mod};
}
P operator + (P a,P b)
{
    return (P){(a.k+b.k)%mod,(a.b+b.b)%mod};
}
P operator - (P a,int b)
{
    return (P){a.k,(a.b-b+mod)%mod};
}
P operator - (P a,P b)
{
    return (P){(a.k-b.k+mod)%mod,(a.b-b.b+mod)%mod};
}
int main()
{
    T=read();
    while(T--)
    {
        n=read();px=read();m=read();K=read();INV=quick_pow(m+1,mod-2);INVX=quick_pow(INV,K);
        if(!K)
        {
            puts("-1");
            continue;
        }
        if(!m)
        {
            if(K==1)  puts("-1");
            else{
                int cnt=0;
                while(px>0)  ++cnt,px=min(n,px+1)-K;
                printf("%d\n",cnt);
            }
            continue;
        }
        for(int i=0;i<=min(n,K);i++)  ci[i]=quick_pow(m,K-i)%mod;
        C[0]=1;for(int i=1;i<=min(n,K);i++)  C[i]=1ll*C[i-1]*(K-i+1)%mod*quick_pow(i,mod-2)%mod;
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=n;j++)
            {
                p[i][j]=(1ll*calc(i-j)*INV%mod*m%mod+1ll*calc(i-j+1)*INV%mod)%mod;
                if(i==n)  p[i][j]=calc(i-j);
            }
        }
        a[1]=(P){1,0};
        for(int i=1;i<=n-1;i++)
        {
            a[i+1]=(P){0,0};
            for(int j=1;j<=i;j++)  a[i+1]=a[i+1]+a[j]*p[i][j];
            a[i+1]=(a[i]-a[i+1]-1)*quick_pow(p[i][i+1],mod-2);
        }
        P sum=(P){0,0};
        for(int i=1;i<=n;i++)  sum=sum+a[i]*p[n][i];
        sum=a[n]-sum-1;
        int st=1ll*(mod-sum.b)*quick_pow(sum.k,mod-2)%mod;
        int ans=(1ll*st*a[px].k%mod+a[px].b)%mod;
        printf("%d\n",ans);
    }
    return 0;
}