P8737 [蓝桥杯 2020 国 B] 质数行者 题解

· · 题解

题意

(1,1,1) 出发,一次只能走质数格且不经过 (r_1.c_1.h_1)(r_2,c_2,h_2),求到达 (n,m,w) 的方案数,答案对 1000000007(即 10^9+7)取模。

思路

看到题目限制不能经过给定的两个点,自然想到容斥,我们记 f(A,B) 为从 A(x_1,y_1,z_1) 到达 B(x_2,y_2,z_2) 的方案数 (x_1 \le x_2,y_1 \le y_2,z_1 \le z_2),记 S(1,1,1), P_1(r_1,c_1,h_1), P_2(r_2,c_2,h_2), T(n,m,w),那么有:

ans=f(S,T)-f(S,P_1) \times f(P_1,T)-f(S,P_2) \times f(P_2,T)-f(S,P_1) \times f(P_1,P_2) \times f(P_2,T)

下面考虑 f(A,B) 如何计算。

只能走质数格的条件比较麻烦,我们先来考虑没有这一条件如何求解。

类似于二维的情况,记 x=x_2-x_1, y=y_2-y_1, z=z_2-z_1 ,那么有

f(A,B) = C_{x+y+z}^z C_{x+y}^y

其意义为:一共走 x+y+z 步,需要向上走 z 步,方案数为 C_{x+y+z}^z;需要向东走 y 步,由于向上已经考虑过了,所以方案数为 C_{x+y}^y,需要向南走 x 步,由于向上和向东已经考虑过了,所以方案数为 C_x^x=1.

现在回到走质数格的限制,观察上面的式子,可以发现我们不关心走的步伐(即走一次的长度),而是次数(即这一段多少次走完) ,也即是说我们要求出某一段长用一些次数走完的方案数。

所以记 g[len][cnt] 表示使用 cnt 个质数走 len 的长度的方案数,则有

g[len][cnt] = \sum_{p \in Prime}^{p \le len} g[len-p][cnt-1]

初态: g[0][0]=1

依据此,我们得到:

f(A,B) = \sum_{i=0}^{\lfloor \frac{x}{2}\rfloor} \sum_{j=0}^{\lfloor \frac{y}{2}\rfloor} \sum_{k=0}^{\lfloor \frac{z}{2}\rfloor} C_{i+j+k}^kC_{i+j}^j \times g[x][i] \times g[y][j] \times g[z][k]

时间复杂度 O(nmw),可以得到 60pts

继续优化,发现瓶颈主要在枚举 i,j,k 上,如果可以在枚举 i 的时候直接计算 j 多好。

当然可以,记 sum=i+j,则有

f(A,B) &= \sum_{i=0}^{\lfloor \frac{x}{2}\rfloor} \sum_{j=0}^{\lfloor \frac{y}{2}\rfloor} \sum_{k=0}^{\lfloor \frac{z}{2}\rfloor} \frac{(i+j+k)!}{i! \times j! \times k!} \times g[x][i] \times g[y][j] \times g[z][k] \\ &= \sum_{i=0}^{\lfloor \frac{x}{2}\rfloor} \sum_{j=0}^{\lfloor \frac{y}{2}\rfloor} \sum_{k=0}^{\lfloor \frac{z}{2}\rfloor} (i+j+k)! \times \frac{g[x][i]}{i!} \times \frac{g[y][j]}{j!} \times \frac{g[z][k]}{k!} \\ &= \sum_{sum=0}^{sum \le \lfloor \frac{x}{2}\rfloor + \lfloor \frac{y}{2}\rfloor} \sum_{i=0}^{sum} \frac{g[x][i]}{i!} \times \frac{g[y][sum-i]}{(sum-i)!} \sum_{k=0}^{\lfloor \frac{z}{2}\rfloor} (sum+k)! \times \frac{g[z][k]}{k!} \end{aligned}

时间复杂度 O(n(m+w)),可以得到 100pts

Code

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
const int N=3e3+10;
const int mod=1e9+7;
int pr[]={0,
2,3,5,7,11,13,17,19,23,29,
31,37,41,43,47,53,59,61,67,71,
73,79,83,89,97,101,103,107,109,113,
127,131,137,139,149,151,157,163,167,173,
179,181,191,193,197,199,211,223,227,229,
233,239,241,251,257,263,269,271,277,281,
283,293,307,311,313,317,331,337,347,349,
353,359,367,373,379,383,389,397,401,409,
419,421,431,433,439,443,449,457,461,463,
467,479,487,491,499,503,509,521,523,541,
547,557,563,569,571,577,587,593,599,601,
607,613,617,619,631,641,643,647,653,659,
661,673,677,683,691,701,709,719,727,733,
739,743,751,757,761,769,773,787,797,809,
811,821,823,827,829,839,853,857,859,863,
877,881,883,887,907,911,919,929,937,941,
947,953,967,971,977,983,991,997,
},cnt=168;// 1000 以内的质数
ll g[N][N],M;
ll fac[N],invfac[N];
struct rec{
    int x,y,z;
}s,p1,p2,t;
ll expow(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1) res=res*a%mod;
        b>>=1;a=a*a%mod;
    }
    return res;
}
void init(int n){
    // 预处理 g
    g[0][0]=1;
    for(int i=1;i<=n;i++){// 距离
        for(int j=1;j<=i/2;j++){// 个数
            for(int k=1;k<=cnt&&pr[k]<=i;k++){// 每个质数
                g[i][j]=(g[i][j]+g[i-pr[k]][j-1])%mod;
            }
        }
    }
    // 预处理阶乘
    n*=3;// i+j+k<3000
    fac[0]=1;
    for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
    invfac[n]=expow(fac[n],mod-2);
    for(int i=n-1;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%mod;
}
ll f(rec a,rec b){
    ll res=0;
    int x=b.x-a.x,y=b.y-a.y,z=b.z-a.z;
    if(x<0||y<0||z<0) return res;// 空集
    for(int sum=0;sum<=x/2+y/2;sum++){
        ll ans1=0;
        for(int i=0;i<=sum;i++){
            int j=sum-i;
            if(i<=x/2&&j<=y/2)
                ans1=(ans1+(g[x][i]*invfac[i]%mod)*(g[y][j]*invfac[j]%mod))%mod;
        }
        ll ans2=0; 
        for(int k=0;k<=z/2;k++) ans2=(ans2+g[z][k]*invfac[k]%mod*fac[sum+k])%mod;
        res=(res+ans1*ans2)%mod;
    }
    return res;
}
int n,m,w;
int r1,c1,h1;
int r2,c2,h2;
int main(){
    scanf("%d %d %d",&n,&m,&w);
    scanf("%d %d %d",&r1,&c1,&h1);
    scanf("%d %d %d",&r2,&c2,&h2);
    s={1,1,1};
    p1={r1,c1,h1};
    p2={r2,c2,h2};
    t={n,m,w};
    init(max(max(n,m),w));
    if(r2<=r1&&c2<=c1&&h2<=h1) swap(p1,p2);
    ll ans1=f(s,t);
    ll ans2=f(s,p1)*f(p1,t)%mod;
    ll ans3=f(s,p2)*f(p2,t)%mod;
    ll ans4=sf(s,p1)*f(p1,p2)%mod*f(p2,t)%mod;
    printf("%lld\n",((ans1-ans2-ans3+ans4)%mod+mod)%mod);// 容斥
    return 0;
}