题解 P5349 【幂】
算法1:
设
设
设
若
设
则
证明:
现在假设
设
则
设
则
考虑
分开考虑,由于
那么
单独考虑每一个
那么
则
又显然
故
归纳得
下面考虑怎么计算
首先
假设
令
则
故
归纳得
将
多项式多点求值
带入答案式子,
算法2:
令
那么
故有
化一下式子,
最后很明显
代码:
算法1:
#pragma GCC optimize ("Ofast")
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int mod=998244353;
const int N=1048580;
inline void rad(int &_){
static char ch;
while(ch=getchar(),ch<'0'||ch>'9');_=ch-48;
while(ch=getchar(),ch<='9'&&ch>='0')_=_*10+ch-48;
}
inline void swap(int &u,int &v){int o=u;u=v;v=o;}
inline int __(int u){return u<mod?u:u-mod;}
inline int ___(int u){return u<0?u+mod:u;}
int ksm(int u,int v){
int res=1;
for(;v;v>>=1,u=1ll*u*u%mod)
if(v&1)res=1ll*res*u%mod;
return res;
}
int f[N],g[N],rnk[N],c[N],d[N],e[N];
int C[N],D[N],cr[N],dr[N],siz[N];
int h[N],fac[N],inv[N],cnt,n,m,r,s,Ans;
vector<int>vp[N];
void Ntt(int *t,int opt,int len){
int g=3,g_=ksm(g,mod-2);
for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]);
for(int i=1;i<len;i<<=1){
int wn=ksm(~opt?g:g_,(mod-1)/(i<<1));
for(int j=0,J=i<<1;j<len;j+=J){
int w=1;
for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){
int r=1ll*w*t[i+k]%mod;
t[i+k]=___(t[k]-r);
t[k]=__(t[k]+r);
}
}
}
if(~opt)return;
int ny=ksm(len,mod-2);
for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod;
}
void Inv(int *a,int Len,int *b){
if(Len==1){b[0]=ksm(a[0],mod-2);return;}
Inv(a,(Len+1)>>1,b);
int len=1,_2=-1;
while(len<Len+Len)len<<=1,_2++;
for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
memcpy(e,a,Len<<2);
memset(e+Len,0,(len-Len)<<2);
Ntt(e,1,len);Ntt(b,1,len);
for(int i=0;i<len;i++)b[i]=1ll*(2-1ll*e[i]*b[i]%mod+mod)*b[i]%mod;
Ntt(b,-1,len);
memset(b+Len,0,(len-Len)<<2);
}
void ntt(int *a,int *b,int len1,int len2,int *t){
int len=1,_2=-1;
while(len<len1+len2)len<<=1,_2++;
for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
memcpy(C,a,len1<<2);
memset(C+len1,0,(len-len1)<<2);
memcpy(D,b,len2<<2);
memset(D+len2,0,(len-len2)<<2);
Ntt(C,1,len);Ntt(D,1,len);
for(int i=0;i<len;i++)t[i]=1ll*C[i]*D[i]%mod;
Ntt(t,-1,len);
if(&b[0]==&dr[0])memset(dr,0,len<<2);
}
void Mod(const int *a,const vector<int>&b,int *t,int lena,int lenb){
for(int i=0;i<=lena;i++)c[i]=a[lena-i];
for(int i=0;i<=lenb;i++)d[i]=b[lenb-i];
int len=1;while(len<=lena+lenb)len<<=1;
memset(c+lena-lenb+1,0,(len-lena+lenb-1)<<2);
memset(d+lena-lenb+1,0,(len-lena+lenb-1)<<2);
Inv(d,lena-lenb+1,dr);
ntt(c,dr,lena-lenb+1,lena-lenb+1,cr);
reverse(cr,cr+(lena-lenb+1));
for(int i=0;i<=lenb;i++)d[i]=b[i];
memset(d+lenb+1,0,(len-lenb-1)<<2);
ntt(d,cr,lenb+1,lena-lenb+1,c);
for(int i=0;i<lenb;i++)t[i]=___(a[i]-c[i]);
}
void Solve(int now,int ls,int rs,int *a){
if(ls==rs){h[++cnt]=a[0];return;}
int noww=now<<1,nrs=ls+rs>>1,b[siz[now]+1];
Mod(a,vp[noww],b,siz[now]-1,siz[noww]);
Solve(noww,ls,nrs,b);
Mod(a,vp[noww|1],b,siz[now]-1,siz[noww|1]);
Solve(noww|1,nrs+1,rs,b);
}
void Solve(int now,int ls,int rs){
siz[now]=rs-ls+1;
if(ls==rs){
vp[now].resize(2);
vp[now][0]=mod-g[ls];
vp[now][1]=1;
return;
}
int noww=now<<1,nrs=ls+rs>>1;
Solve(noww,ls,nrs);Solve(noww|1,nrs+1,rs);
int len=1,_2=-1;
while(len<=siz[now])len<<=1,_2++;
for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
for(int i=0;i<=siz[noww];i++)c[i]=vp[noww][i];
for(int i=0;i<=siz[noww|1];i++)d[i]=vp[noww|1][i];
memset(c+siz[noww]+1,0,(len-siz[noww]-1)<<2);
memset(d+siz[noww|1]+1,0,(len-siz[noww|1]-1)<<2);
Ntt(c,1,len);Ntt(d,1,len);
for(int i=0;i<len;i++)c[i]=1ll*c[i]*d[i]%mod;
Ntt(c,-1,len);
vp[now].resize(siz[now]+1);
for(int i=0;i<=siz[now];i++)vp[now][i]=c[i];
}
int main(){
rad(n);m=n+1;rad(r);
for(int i=0;i<=n;i++)rad(f[i]);
for(int i=1;i<=m;i++)g[i]=i;
s=f[0];
Solve(1,1,m);
Solve(1,1,m,f);
fac[0]=fac[1]=inv[0]=inv[1]=1;
for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
inv[n]=ksm(fac[n],mod-2);
for(int i=n-1;i>=2;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
for(int i=0;i<=n;i++)f[i]=___((i&1?-1:1)*inv[i]);
for(int i=0;i<=n;i++)g[i]=1ll*h[i+1]*inv[i]%mod;
ntt(f,g,n+1,n+1,h);
r=1ll*r*ksm(1-r+mod,mod-2)%mod;
for(int i=0,j=1;i<=n;i++,j=1ll*j*r%mod)
h[i]=1ll*h[i]*fac[i]%mod*j%mod;
for(int i=0;i<=n;i++)Ans=__(Ans+h[i]);
Ans=__(1ll*Ans*r%mod+s);
printf("%d\n",Ans);
}
算法2:
#include<cstdio>
const int N=262150;
const int mod=998244353;
int n,Ans,r,r_,A[N],fac[N],inv[N];
int f[N],g[N],c[N],rnk[N];
int ksm(int u,int v){
int res=1;
for(;v;v>>=1,u=1ll*u*u%mod)
if(v&1)res=1ll*res*u%mod;
return res;
}
inline void swap(int &u,int &v){int o=u;u=v;v=o;}
inline int _(int u){return u<mod?u:u-mod;}
inline int __(int u){return u<0?u+mod:u;}
void ntt(int *t,int opt,int len){
int g=3,g_=ksm(g,mod-2);
for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]);
for(int i=1;i<len;i<<=1){
int wn=ksm(~opt?g:g_,(mod-1)/(i<<1));
for(int j=0,J=i<<1;j<len;j+=J){
int w=1;
for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){
int r=1ll*w*t[i+k]%mod;
t[i+k]=__(t[k]-r);
t[k]=_(t[k]+r);
}
}
}
if(~opt)return;
int ny=ksm(len,mod-2);
for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod;
}
void Inv(int Len,int *a,int *b){
if(Len==1){b[0]=ksm(a[0],mod-2);return;}
Inv((Len+1)>>1,a,b);
int len=1,_2=-1;
while(len<Len+Len)len<<=1,_2++;
for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
for(int i=0;i<Len;i++)c[i]=a[i];
for(int i=Len;i<len;i++)c[i]=0;
ntt(c,1,len);ntt(b,1,len);
for(int i=0;i<len;i++)
b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)*b[i]%mod;
ntt(b,-1,len);
for(int i=Len;i<len;i++)b[i]=0;
}
int main(){
scanf("%d%d",&n,&r);r_=1ll*r*ksm(1-r+mod,mod-2)%mod;
for(int i=0;i<=n;i++)scanf("%d",&A[i]);
fac[0]=inv[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
inv[n]=ksm(fac[n],mod-2)%mod;
for(int i=n-1;i>=1;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
for(int i=1;i<=n;i++)g[i]=mod-1ll*inv[i]*r_%mod;
g[0]=1;
Inv(n+1,g,f);
for(int i=0;i<=n;i++)f[i]=1ll*f[i]*fac[i]%mod*ksm(1-r+mod,mod-2)%mod;
for(int i=0;i<=n;i++)Ans=_(Ans+1ll*f[i]*A[i]%mod);
printf("%d\n",Ans);
}