P9551 「PHOI-1」斗之魂 题解

· · 题解

这是本蒟蒻第三十三次写的题解,如有错误点请好心指出!

显然如果小 X 用第 1 种方式击败第 i 个 BOSS,当 k_{i,0} 确定时,k_{i,1}k_{i,2} 也是确定的,考虑小 X 用第 2 种方式击败第 i 个 BOSS 的情况,先对 \dfrac{1}{k_{i,0}}=\dfrac{1}{k_{i,1}}+\dfrac{1}{k_{i,2}} 中右边式子进行通分,则

\dfrac{1}{k_{i,0}}=\dfrac{k_{i,1}+k_{i,2}}{k_{i,1}k_{i,2}}

去分母得:

\begin{aligned} k_{1,0}(k_{i,1}+k_{i,2})&=k_{i,1}k_{i,2}\\ 0&=-k_{1,0}(k_{i,1}+k_{i,2})+k_{i,1}k_{i,2} \end{aligned}

两边各加上一个 k_{1,0}^2 得:

k_{i,0}^2=k_{i,0}^2-k_{1,0}(k_{i,1}+k_{i,2})+k_{i,1}k_{i,2}

发现右边式子是个完全平方公式,因式分解得:

k_{i,0}^2=(k_{i,0}-k_{i,1})(k_{i,0}-k_{i,2})

因为 k_{i,0},k_{i,1},k_{i,2} 均为正整数且 k_{i,1},k_{i,2} 一定大于 k_{i,0},所以当 k_{i,0} 确定时,k_{i,1}k_{i,2} 的方案数为 k_{i,0}^2 的因子个数。

k_{i,0} 质因数分解成 p_1^{\alpha_1} \times p_2^{\alpha_2} \times p_3^{\alpha_3} \cdots p_x^{\alpha_x},则 k_{i,0}^2 的因子个数为 (2\alpha_1+1) \times (2\alpha_2+1) \times (2\alpha_3+1)\cdots(2\alpha_x+1),用线性筛 O(m) 维护最小质因子次数并求出。

预处理好因子个数之后,就可以用 dp 求方案数了。设 f_{i,j} 为当前击败第 i 个 BOSS,获得的总稀有金属个数为 j 的方案数,最终答案为 f_{n,m},则当 b_i=1 时,转移方程为

f_{i,j}=\sum_{k=1}^{j-1}f_{i-1,k}

b_i=2 时,转移方程为

f_{i,j}=\sum_{k=1}^{j-1}f_{i-1,k} \times a_{(j-k)^2}

其中 a_{(j-k)^2}(j-k)^2 的因子个数。

发现转移方程均与击败 BOSS 的顺序无关,我们可以先全部处理第 1 种转移方程,再全部处理第 2 种转移方程,其中第 1 种转移方程可以用组合数学计算,我们直接考虑第 2 种转移方程,记 g_i=a_{i^2},则转移方程为

f_{i,j}=\sum_{k=1}^{j-1}f_{i-1,k} \times g_{j-k}

设状态函数 F_t(x),转移函数 G(x),答案函数 A(x),使

F_t(x)=\sum_{i=0}^\infty f_{t,i}x^i G(x)=\sum_{i=1}^{j-1}g_{i}x^i A(x)=\sum_{i=0}^\infty \sum_{j=0}^k f_{t,i} g_{j} x^i

其中 F_0(x)=1,则

\begin{aligned} F_t(x)&=F_{t-1}(x)G(x)\\ &=G^t(x) \end{aligned}

然后答案函数为

\begin{aligned} A(x)&=F_t(x)\\ &=G^t(x)\\ &=\exp(t \ln G(x)) \end{aligned}

用多项式快速幂即可做到时间复杂度为 O(m \log m+q)

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long ll;
const ll mod=998244353,gg=3,ggi=(mod+1)/3;
ll yz[250005]={0,1},zs[250005],tot[250005],cnt;
int inv[600005]={0,1},jc[600005]={1},ni[600005];
int n,q,mx,m[100005],bl[600005],F[600005],G[600005],G1[600005],G2[600005],b1[600005],c1[600005],d1[600005],e1[600005],lim,ss,cnt1,cnt2;
bool bz[250005];
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
inline void ycl()
{
    for(int i=2;i<=mx;i++)
    {
        if(!bz[i])
        {
            zs[++cnt]=i;
            tot[i]=1;
            yz[i]=3;
        }
        for(int j=1;j<=cnt;j++)
        {
            if(i*zs[j]>mx) break;
            bz[i*zs[j]]=1;
            if(i%zs[j]==0)
            {
                tot[i*zs[j]]=tot[i]+1;
                yz[i*zs[j]]=yz[i]/(2*tot[i]+1)*(2*tot[i]+3);
                break;
            }
            tot[i*zs[j]]=1;
            yz[i*zs[j]]=yz[i]*3;
        }
    }
}
inline int ksm(int x,int y)
{
    int res=1;
    while(y)
    {
        if(y&1) res=(ll)res*x%mod;
        x=(ll)x*x%mod;
        y>>=1;
    }
    return res;
}
void NTT(int *A,int type)
{
    for(int i=0;i<lim;i++)
    if(i<bl[i]) swap(A[i],A[bl[i]]);
    for(int i=2;i<=lim;i<<=1)
    {
        int mid=i>>1;
        ll gn=ksm(type==1?gg:ggi,(mod-1)/i);
        for(int j=0;j<lim;j+=i)
        {
            ll mi=1;
            for(int k=j;k<j+mid;k++,mi=mi*gn%mod)
            {
                ll ax=A[k],ay=(ll)mi*A[k+mid]%mod;
                A[k]=(ax+ay)%mod;
                A[k+mid]=(ax-ay+mod)%mod;
            }
        }
    }
    if(type==0)
    {
        ll inv=ksm(lim,mod-2);
        for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
    }
}
void init(int len)
{
    lim=1,ss=0;
    while(lim<=len) lim<<=1,ss++;
    for(int i=0;i<lim;i++) bl[i]=(bl[i>>1]>>1)|((i&1)<<(ss-1));
}
void getinv(int *A,int *B,int len)
{
    if(len==1)
    {
        B[0]=ksm(A[0],mod-2);
        return;
    }
    getinv(A,B,(len+1)>>1);
    init(len<<1);
    for(int i=0;i<len;i++) d1[i]=A[i];
    NTT(B,1);NTT(d1,1);
    for(int i=0;i<lim;i++) B[i]=(ll)B[i]*(2-(ll)B[i]*d1[i]%mod+mod)%mod;
    NTT(B,0);
    for(int i=0;i<len;i++) d1[i]=0;
    for(int i=len;i<lim;i++) B[i]=d1[i]=0;
}
void ln(int *A,int *B,int len)
{
    for(int i=1;i<len;i++) b1[i-1]=(ll)A[i]*i%mod;
    b1[len]=0;
    getinv(A,c1,len);
    init(len<<1);
    NTT(b1,1);NTT(c1,1);
    for(int i=0;i<lim;i++) b1[i]=(ll)b1[i]*c1[i]%mod;
    NTT(b1,0);
    for(int i=1;i<len;i++) B[i]=(ll)b1[i-1]*inv[i]%mod;
    B[0]=0;
    for(int i=0;i<lim;i++) b1[i]=c1[i]=0;
}
void exp(int *A,int *B,int len)
{
    if(len==1)
    {
        B[0]=1;
        return;
    }
    exp(A,B,(len+1)>>1);
    ln(B,e1,len);
    e1[0]=(A[0]+1-e1[0]+mod)%mod;
    for(int i=1;i<len;i++) e1[i]=(A[i]-e1[i]+mod)%mod;
    init(len<<1);
    NTT(B,1);NTT(e1,1);
    for(int i=0;i<lim;i++) B[i]=(ll)B[i]*e1[i]%mod;
    NTT(B,0);
    for(int i=len;i<lim;i++) B[i]=e1[i]=0;
}
int main()
{
    n=read();q=read();
    for(int i=1;i<=n;i++)
    {
        ll x=read();
        if(x==1) cnt1++;
        else cnt2++;
    }
    for(int i=1;i<=q;i++) m[i]=read(),mx=max(mx,m[i]);
    ycl();
    for(int i=2;i<=600000;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
    for(int i=1;i<=n+mx;i++) jc[i]=(ll)jc[i-1]*i%mod;
    ni[n+mx]=ksm(jc[n+mx],mod-2);
    for(int i=n+mx-1;i>=0;i--) ni[i]=(ll)ni[i+1]*(i+1)%mod;
    if(!cnt1) F[0]=1;
    else
    {
        for(int i=0;i<=mx;i++) F[i]=(ll)jc[i+cnt1-1]*ni[i]%mod*ni[cnt1-1]%mod;
        F[mx]=(F[mx]-1+mod)%mod;
    }
    if(!cnt2) G2[0]=1;
    else
    {
        for(int i=0;i<mx;i++) G[i]=yz[i+1];
        ln(G,G1,mx+1);
        for(int i=0;i<=mx;i++) G1[i]=(ll)G1[i]*cnt2%mod;
        for(int i=mx+1;i<lim;i++) G1[i]=0;
        exp(G1,G2,mx+1);
    }
    init((mx+1)<<1);
    NTT(F,1);NTT(G2,1);
    for(int i=0;i<lim;i++) F[i]=(ll)F[i]*G2[i]%mod;
    NTT(F,0);
    for(int i=1;i<=q;i++)
    {
        if(m[i]<n) printf("0\n");
        else printf("%d\n",F[m[i]-n]);
    }
    return 0;
}