P8558 黑暗(Darkness)

· · 题解

可以先考虑计算从 (A,B,C) 走到 (0,0,0) 而不撞墙的概率,有递推关系:

f_{A,B,C}=\frac 13 (f_{A-1,B,C}+f_{A,B-1,C}+f_{A,B,C-1})

建立生成函数,设 f_{A,B,C}=[x^Ay^Bz^C]F,则有

F= \frac 13(x+y+z)F+1

多出来的这个 +1 就是因为要补上 f_{0,0,0}=1 这一项。

F = \frac{1}{1-(x+y+z)/3}=\sum_{i=0}^\infty \frac{1}{3^i}(x+y+z)^i

提取系数即得

f_{A,B,C}= \frac{1}{3^{A+B+C}}\binom{A+B+C}{A,B,C}

那么计算从 (A,B,C) 走到 (i,j,k) 而不撞墙概率就容易计算了。其递推式与 f 很像,再稍加分析就可以发现所求是 f_{A-i,B-j,C-k}

现在只需要枚举墙边的位置,根据走到那撞到墙的概率,就可以将答案表示为:

g(A,B,C)+g(A,C,B)+g(B,C,A) g(d_1,d_2,d_3)=\frac 13\sum_{i=0}^{d_1}\sum_{j=0}^{d_2} \binom{i+j+d_3}{i,j,d_3}\frac{(d_1+d_2-i-j)^k}{3^{i+j+d_3}}

你会发现有些位置可能被重复计算了,但没关系。被计了两次的位置,都是有两个方向面对墙的位置。

这里就说 g(A,B,C) 的计算方法,剩下的两部分也是一样的。可以发现 i+j 在和式中频繁出现,考虑换元枚举 t=i+j

3g(A,B,C)=\frac{1}{3^C}\sum_{t=0}^{A+B} \frac{(A+B-t)^k}{3^t}\sum_{j=t-A}^B \binom{t+C}{t-j,j,C}

后面那个和式可以单独计算,先化简为

\binom{t+C}{C} \sum_{j=t-A}^B \binom{t}{j}

这种超几何式显然可以整式递推计算,利用二项式系数的递推公式:

D_n=\sum_{i=n-A}^B \binom ni = \sum_{i=n-A}^B \binom{n-1}{i}+\binom{n-1}{i-1} D_n=2D_{n-1} -\binom{n-1}{A}- \binom{n-1}{B}

可以在线性时间内递推计算,最后用线性筛来求 t^k 即可。

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 10000003
#define ll long long
#define p 998244353
using namespace std;

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1;
    }
    return res;
}

const int inv3 = 332748118;
int fac[N],ifac[N],pw[N],pr[N>>3],inv[N];
int cnt;

void init(int n,int k){
    inv[1] = pw[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
    for(int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
    ifac[n] = power(fac[n],p-2);
    for(int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
    for(int i=2;i<=n;++i){
        inv[i] = (ll)fac[i-1]*ifac[i]%p;
        if(!pw[i]){
            pr[++cnt] = i;
            pw[i] = power(i,k);
        }
        for(int j=1;j<=cnt&&i*pr[j]<=n;++j){
            pw[i*pr[j]] = (ll)pw[i]*pw[pr[j]]%p;
            if(i%pr[j]==0) break;
        }
    }
}

int g(int A,int B,int C,int k){
    if(A>B) swap(A,B);
    ll res = 0;
    int bc = 1,f = 1,ifa = ifac[A],ifb = ifac[B],ipw3 = 1;
    for(int i=0;i<=A+B;++i){
        res += (ll)pw[A+B-i]*f%p*ipw3%p*bc%p;
        ipw3 = ipw3*332748118ll%p;
        bc = (ll)bc*(i+C+1)%p*inv[i+1]%p;
        if(i<A) f = (f<<1)>=p?(f<<1)-p:f<<1;
        else if(i<B) f = ((f<<1)-(ll)fac[i]*ifa%p*ifac[i-A])%p;
        else f = ((f<<1)-((ll)ifa*ifac[i-A]+(ll)ifb*ifac[i-B])%p*fac[i])%p;
    }
    return res%p*power(3,p-C-1)%p;
}

int A,B,C,k;

int main(){
    scanf("%d%d%d%d",&A,&B,&C,&k);
    init(A+B+C-min(min(A,B),C),k);
    int ans = (ll)((g(A,B,C,k)+g(A,C,B,k))%p+g(B,C,A,k))*inv3%p;
    printf("%d",(ans+p)%p);
    return 0;  
}