P9551 「PHOI-1」斗之魂 题解
这是本蒟蒻第三十三次写的题解,如有错误点请好心指出!
显然如果小 X 用第
去分母得:
两边各加上一个
发现右边式子是个完全平方公式,因式分解得:
因为
将
预处理好因子个数之后,就可以用 dp 求方案数了。设
当
其中
发现转移方程均与击败 BOSS 的顺序无关,我们可以先全部处理第
设状态函数
其中
然后答案函数为
用多项式快速幂即可做到时间复杂度为
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long ll;
const ll mod=998244353,gg=3,ggi=(mod+1)/3;
ll yz[250005]={0,1},zs[250005],tot[250005],cnt;
int inv[600005]={0,1},jc[600005]={1},ni[600005];
int n,q,mx,m[100005],bl[600005],F[600005],G[600005],G1[600005],G2[600005],b1[600005],c1[600005],d1[600005],e1[600005],lim,ss,cnt1,cnt2;
bool bz[250005];
inline ll read()
{
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
inline void ycl()
{
for(int i=2;i<=mx;i++)
{
if(!bz[i])
{
zs[++cnt]=i;
tot[i]=1;
yz[i]=3;
}
for(int j=1;j<=cnt;j++)
{
if(i*zs[j]>mx) break;
bz[i*zs[j]]=1;
if(i%zs[j]==0)
{
tot[i*zs[j]]=tot[i]+1;
yz[i*zs[j]]=yz[i]/(2*tot[i]+1)*(2*tot[i]+3);
break;
}
tot[i*zs[j]]=1;
yz[i*zs[j]]=yz[i]*3;
}
}
}
inline int ksm(int x,int y)
{
int res=1;
while(y)
{
if(y&1) res=(ll)res*x%mod;
x=(ll)x*x%mod;
y>>=1;
}
return res;
}
void NTT(int *A,int type)
{
for(int i=0;i<lim;i++)
if(i<bl[i]) swap(A[i],A[bl[i]]);
for(int i=2;i<=lim;i<<=1)
{
int mid=i>>1;
ll gn=ksm(type==1?gg:ggi,(mod-1)/i);
for(int j=0;j<lim;j+=i)
{
ll mi=1;
for(int k=j;k<j+mid;k++,mi=mi*gn%mod)
{
ll ax=A[k],ay=(ll)mi*A[k+mid]%mod;
A[k]=(ax+ay)%mod;
A[k+mid]=(ax-ay+mod)%mod;
}
}
}
if(type==0)
{
ll inv=ksm(lim,mod-2);
for(int i=0;i<lim;i++) A[i]=(ll)A[i]*inv%mod;
}
}
void init(int len)
{
lim=1,ss=0;
while(lim<=len) lim<<=1,ss++;
for(int i=0;i<lim;i++) bl[i]=(bl[i>>1]>>1)|((i&1)<<(ss-1));
}
void getinv(int *A,int *B,int len)
{
if(len==1)
{
B[0]=ksm(A[0],mod-2);
return;
}
getinv(A,B,(len+1)>>1);
init(len<<1);
for(int i=0;i<len;i++) d1[i]=A[i];
NTT(B,1);NTT(d1,1);
for(int i=0;i<lim;i++) B[i]=(ll)B[i]*(2-(ll)B[i]*d1[i]%mod+mod)%mod;
NTT(B,0);
for(int i=0;i<len;i++) d1[i]=0;
for(int i=len;i<lim;i++) B[i]=d1[i]=0;
}
void ln(int *A,int *B,int len)
{
for(int i=1;i<len;i++) b1[i-1]=(ll)A[i]*i%mod;
b1[len]=0;
getinv(A,c1,len);
init(len<<1);
NTT(b1,1);NTT(c1,1);
for(int i=0;i<lim;i++) b1[i]=(ll)b1[i]*c1[i]%mod;
NTT(b1,0);
for(int i=1;i<len;i++) B[i]=(ll)b1[i-1]*inv[i]%mod;
B[0]=0;
for(int i=0;i<lim;i++) b1[i]=c1[i]=0;
}
void exp(int *A,int *B,int len)
{
if(len==1)
{
B[0]=1;
return;
}
exp(A,B,(len+1)>>1);
ln(B,e1,len);
e1[0]=(A[0]+1-e1[0]+mod)%mod;
for(int i=1;i<len;i++) e1[i]=(A[i]-e1[i]+mod)%mod;
init(len<<1);
NTT(B,1);NTT(e1,1);
for(int i=0;i<lim;i++) B[i]=(ll)B[i]*e1[i]%mod;
NTT(B,0);
for(int i=len;i<lim;i++) B[i]=e1[i]=0;
}
int main()
{
n=read();q=read();
for(int i=1;i<=n;i++)
{
ll x=read();
if(x==1) cnt1++;
else cnt2++;
}
for(int i=1;i<=q;i++) m[i]=read(),mx=max(mx,m[i]);
ycl();
for(int i=2;i<=600000;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
for(int i=1;i<=n+mx;i++) jc[i]=(ll)jc[i-1]*i%mod;
ni[n+mx]=ksm(jc[n+mx],mod-2);
for(int i=n+mx-1;i>=0;i--) ni[i]=(ll)ni[i+1]*(i+1)%mod;
if(!cnt1) F[0]=1;
else
{
for(int i=0;i<=mx;i++) F[i]=(ll)jc[i+cnt1-1]*ni[i]%mod*ni[cnt1-1]%mod;
F[mx]=(F[mx]-1+mod)%mod;
}
if(!cnt2) G2[0]=1;
else
{
for(int i=0;i<mx;i++) G[i]=yz[i+1];
ln(G,G1,mx+1);
for(int i=0;i<=mx;i++) G1[i]=(ll)G1[i]*cnt2%mod;
for(int i=mx+1;i<lim;i++) G1[i]=0;
exp(G1,G2,mx+1);
}
init((mx+1)<<1);
NTT(F,1);NTT(G2,1);
for(int i=0;i<lim;i++) F[i]=(ll)F[i]*G2[i]%mod;
NTT(F,0);
for(int i=1;i<=q;i++)
{
if(m[i]<n) printf("0\n");
else printf("%d\n",F[m[i]-n]);
}
return 0;
}