题解 P4007 【小 Y 和恐怖的奴隶主】

· · 题解

想了两天,终于搞出来了……

首先一个普通的DP应该是很简单的。我们设 f_{i,j,k,l} 表示进行 i 步后,有 j 个三血,k 个两血,l 个一血的概率

我们考虑推式子,就会发现一些性质:

  1. 无论从 i-1 推至 i 还是从 i 推至 i+1,能够互相转移的对 \{j,k,l\}\rightarrow\{j',k',l'\} 总是固定的;

  2. 所有合法的 \{j,k,l\} 的个数较少(最多 165 对)。

于是我们考虑矩阵优化。

这里我们发现,如果把第 i 个时刻的期望也作为一个状态加到矩阵里面去,转移起来将会很方便——

但是无奈我太弱了,没想到,于是便有了下面一种解法:

我们令 A 为状态转移的 n\times n 矩阵,p 为一 1\times n 矩阵,表示初始概率(即,除了 \{1,0,0\} 态为 1,其余位置都为 0 的矩阵),f 为一 n\times1 矩阵,表示从某个状态出发,有多大的概率进行一次攻击并打到Boss。

我们有 i 时刻的概率向量即为

p\times A^i

而对于一概率向量 v,其总打到Boss的概率是 v\times f(实际上该式子的结果是一 1\times1 矩阵,但我们此处把它看作一个数)。

于是我们就有 i 时刻进行一次攻击打到 Boss 的概率是

p\times A^i\times f

(注意乘法的顺序不可颠倒)

则我们最终的答案就是

\sum\limits_{i=0}^{n-1}p\times A^i\times f

因为矩阵乘法具有左结合律和右结合律,所以它可以被转为

p\times\Big(\sum\limits_{i=0}^{n-1}A^i\Big)\times f

于是我们考虑如何求出 \sum\limits_{i=0}^{n-1}A^i

显然这是一个等比矩阵求和的形式,可以类似等比数列求和的形式套用 \dfrac{A^n-I}{A-I};但是很遗憾,此处的 A-I 并非可逆矩阵,所以我们不得不放弃这个讨巧的办法。

于是我们只能考虑类似于矩阵快速幂地处理。道理很简单,我们可以对所有 2 的次幂预处理出来 \sum\limits_{i=0}^{n-1}A^i,然后拼接就能求出任意 n 的答案,具体可以参考代码,复杂度同矩阵快速幂。

但是我们稍微估算一下就会发现,这有问题——总复杂度为 O(165^3T\log n ),似乎不可能通过。

这时我们就想到今年NOI好像也出现了一道矩阵快速幂的题,但是正解中它里面直接求的就是类似 pA^n 的东西,复杂度是 N^2 的。

于是我们考虑求出 p\times\Big(\sum\limits_{i=0}^{n-1}A^i\Big)。因为全过程只需要计算一个 1\times n 的矩阵乘上一个 n\times n 的矩阵,所以单次乘法的复杂度就变成了 165^2

于是总复杂度就变成了 O(165^2T\log n ),可以通过洛谷数据(但是无法通过LOJ数据,因为懒得卡常了)。

通过模板化矩阵和向量(即是 1\times n 矩阵),代码可以做到极其整洁。

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int T,m,lim,N,inv[100];
int id[9][9][9];
int ksm(int x,int y){
    int z=1;
    for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;
    return z;
}
struct Matrix{
    int a[166][166];
    Matrix(){memset(a,0,sizeof(a));}
    int* operator [](int x){return a[x];}
    friend Matrix operator *(Matrix &x,Matrix &y){
        Matrix z;
        for(int i=1;i<=N;i++)for(int j=1;j<=N;j++)for(int k=1;k<=N;k++)(z[i][j]+=1ll*x[i][k]*y[k][j]%mod)%=mod;
        return z;
    }
}M[70];
struct Vector{
    int a[166];
    Vector(){memset(a,0,sizeof(a));}
    int& operator [](int x){return a[x];}
    friend Vector operator *(Vector &x,Matrix &y){
        Vector z;
        for(int i=1;i<=N;i++)for(int j=1;j<=N;j++)(z[j]+=1ll*x[i]*y[i][j]%mod)%=mod;
        return z;
    }
    friend int operator *(Vector &x,Vector &y){
        int z=0;
        for(int i=1;i<=N;i++)(z+=1ll*x[i]*y[i]%mod)%=mod;
        return z;
    }
    friend Vector operator +(Vector x,Vector y){
        Vector z;
        for(int i=1;i<=N;i++)z[i]=(x[i]+y[i])%mod;
        return z;
    }
}S[70],F;
long long n;
int Calc(){
    Vector R;
    for(int i=0;i<60;i++){
        if(!((n>>i)&1))continue;
        R=R*M[i];
        R=R+S[i];
    }
    return R*F;
}
int main(){
    scanf("%d%d%d",&T,&m,&lim);
    for(int i=0;i<=lim+1;i++)inv[i]=ksm(i,mod-2);
    Matrix &B=M[0];
    if(m==1){
        for(int i=0;i<=lim;i++)id[i][0][0]=++N,F[N]=inv[i+1];
        for(int i=0;i<=lim;i++){
            if(i>0)B[id[i][0][0]][id[i-1][0][0]]=1ll*i*inv[i+1]%mod;
            B[id[i][0][0]][id[i][0][0]]=inv[i+1];
        }
    }
    if(m==2){
        for(int i=0;i<=lim;i++)for(int j=0;i+j<=lim;j++)id[i][j][0]=++N,F[N]=inv[i+j+1];
        for(int i=0;i<=lim;i++)for(int j=0;i+j<=lim;j++){
            if(i)B[id[i][j][0]][id[i+(i+j!=lim)-1][j+1][0]]=1ll*i*inv[i+j+1]%mod;
            if(j)B[id[i][j][0]][id[i][j-1][0]]=1ll*j*inv[i+j+1]%mod;
            B[id[i][j][0]][id[i][j][0]]=inv[i+j+1];
        }
    }
    if(m==3){
        for(int i=0;i<=lim;i++)for(int j=0;i+j<=lim;j++)for(int k=0;i+j+k<=lim;k++)id[i][j][k]=++N,F[N]=inv[i+j+k+1];
        for(int i=0;i<=lim;i++)for(int j=0;i+j<=lim;j++)for(int k=0;i+j+k<=lim;k++){
            if(i)B[id[i][j][k]][id[i+(i+j+k!=lim)-1][j+1][k]]=1ll*i*inv[i+j+k+1]%mod;
            if(j)B[id[i][j][k]][id[i+(i+j+k!=lim)][j-1][k+1]]=1ll*j*inv[i+j+k+1]%mod;
            if(k)B[id[i][j][k]][id[i][j][k-1]]=1ll*k*inv[i+j+k+1]%mod;
            B[id[i][j][k]][id[i][j][k]]=inv[i+j+k+1];
        }
    }
//  for(int i=1;i<=N;i++){for(int j=1;j<=N;j++)printf("%9d ",B[i][j]);puts("");}
    for(int i=1;i<60;i++)M[i]=M[i-1]*M[i-1];

    S[0][id[1][0][0]]=1;
    for(int i=1;i<60;i++)S[i]=(S[i-1]*M[i-1])+S[i-1];

    while(T--)scanf("%lld",&n),printf("%d\n",Calc());
    return 0;
}