题解 P5349 【幂】

· · 题解

算法1:

\{a_n\}是等差数列,首项为a_1,公差为d

\{b_n\}是等比数列,首项为b_1,公比为q(\,\left|\,q\,\right|\lt1)

S(x)=\sum_{i=1}^{\infty}a_ib_i,则

S(x)=a_1b_1+(a_1+d)b_1q+(a_1+2d)b_1q^2+\cdots \ \ \ \ \ \ \ \,\,=b_1\big[a_1(1+q+q^2+\cdots)+dq(1+2q+3q^2+\cdots)\big] \ \ \ \ \ \ \ \,\,=b_1\big[a_1\frac{1}{1-q}+dq(\frac{1}{1-q})^2\big] \ \ \ \ \ \ \ \,\,=\frac{b_1}{1-q}\big(a_1+d\cdot \frac{q}{1-q}\big)

\{a_n\}m阶等差数列。

d_0,d_1,d_2,\dots d_m为数列\{a_n\}0,1,2,\dots m阶差数列的首项。

S(x)=\sum\limits_{i=1}^{\infty}a_ib_i=\frac{b_1}{1-q}\sum\limits_{i=0}^m d_i(\frac{q}{1-q})^i

证明:m=1时,显然d_0=a_1d_1=d,原式成立。

现在假设m=r时成立,下面证明m=r+1时也成立。

\{\Delta_n\}=\{a_{n+1}-a_n\},那么\{\Delta_n\}r阶等差数列

\sum\limits_{i=1}^{\infty}\Delta_ib_i=\frac{b_1}{1-q}\sum\limits_{k=0}^r d_{k+1}(\frac{q}{1-q})^k

\sum\limits_{i=1}^{\infty}\Delta_ib_i的前n项和为D_n\sum\limits_{i=1}^{\infty}a_ib_i的前n项和为S_n

\lim\limits_{n\to \infty}D_n=\frac{b_1}{1-q}\sum\limits_{k=0}^r d_{k+1}(\frac{q}{1-q})^k

考虑S_n=a_1b_1+a_2b_2+a_3b_3+\cdots+a_nb_n

qS_n=a_1b_2+a_2b_3+a_3b_4+\cdots+a_nb_{n+1} \big(1-q\big)S_n=a_1b_1-a_nb_{n+1}+(a_2-a_1)b_2+(a_3-a_2)b_3+\cdots+(a_n-a_{n-1})b_n \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \,\,=a_1b_1-a_nb_{n+1}+\Delta_1b_2+\Delta_2b_3+\cdots+\Delta_{n-1}b_n \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \,\,=a_1b_1-a_nb_{n+1}+q\big(\Delta_1b_1+\Delta_2b_2+\cdots+\Delta_{n-1}b_{n-1}\big) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \,\,=a_1b_1-a_nb_{n+1}+qD_{n-1} S(x)=\lim\limits_{n\to \infty}S_n=\frac{1}{1-q}\big(a_1b_1-\lim_{n\to \infty}a_nb_{n+1}+q\lim_{n\to \infty}D_{n-1}\big)

分开考虑,由于\{a_n\}r+1阶的等差数列,那么必然可以表达成r+1次多项式。

那么\lim\limits_{n\to \infty}a_nb_{n+1}=\lim\limits_{n\to \infty}\sum\limits_{k=0}^{r+1}c_kn^k\cdot b_1q^n=b_1\sum\limits_{k=0}^{r+1}c_k\lim\limits_{n\to \infty}n^kq^n

单独考虑每一个k,那么有

\lim\limits_{n\to \infty}\frac{|(n+1)^kq^{n+1}|}{|n^kq^n|}=\lim\limits_{n\to \infty}\big|q(1+\frac{1}{n})^k\big|=\big|q\big|\lt 1

那么\lim\limits_{n\to \infty}n^kq^n=0\ \ (k=0,1,2,\dots,r+1)

\lim\limits_{n\to \infty}a_nb_{n+1}=0

又显然\lim\limits_{n\to \infty}D_{n-1}=\lim\limits_{n\to \infty}D_{n}=\frac{b_1}{1-q}\sum\limits_{k=0}^r d_{k+1}(\frac{q}{1-q})^k,所以

S(x)=\frac{1}{1-q}\big(a_1b_1+q\frac{b_1}{1-q}\sum\limits_{k=0}^r d_{k+1}(\frac{q}{1-q})^k\big) \ \ \ \ \ \ \ \,\,=\frac{b_1}{1-q}\big(a_1+\frac{q}{1-q}\sum\limits_{k=0}^r d_{k+1}(\frac{q}{1-q})^k\big) \ \ \ \ \ \ \ \,\,=\frac{b_1}{1-q}\big(a_1+\sum\limits_{k=1}^{r+1} d_k(\frac{q}{1-q})^k\big) \ \ \ \ \ \ \ \,\,=\frac{b_1}{1-q}\sum\limits_{k=0}^{r+1} d_k(\frac{q}{1-q})^k

m=r+1时成立。

归纳得S(x)=\frac{b_1}{1-q}\sum\limits_{k=0}^m d_k(\frac{q}{1-q})^k

下面考虑怎么计算d_k

d_k=\sum\limits_{i=0}^k(-1)^i\binom{k}{i}a_{k+1-i}

首先d_0=a_0,符合上面式子。

假设k=r时上式成立,下面证明k=r+1时成立。

p_r表示a_2,a_3,\dots a_{r+2}作差分时r阶差数列的首项,那么p_r=\sum\limits_{i=0}^r(-1)^i\binom{r}{i}a_{r+2-i}

d_{r+1}=p_r-d_r=\sum\limits_{i=0}^r(-1)^i\binom{r}{i}a_{r+2-i}-\sum\limits_{i=0}^r(-1)^i\binom{r}{i}a_{r+1-i}

=\sum\limits_{i=1}^r(-1)^i\binom{r}{i}a_{r+2-i}-\sum\limits_{i=0}^{r-1}(-1)^i\binom{r}{i}a_{r+1-i}+a_{r+2}-(-1)^ra_1 =\sum\limits_{i=1}^r(-1)^i\binom{r}{i}a_{r+2-i}+\sum\limits_{i=1}^r(-1)^i\binom{r}{i-1}a_{r+2-i}+a_{r+2}-(-1)^ra_1 =\sum\limits_{i=1}^r(-1)^i\big(\binom{r}{i}+\binom{r}{i-1}\big)a_{r+2-i}+a_{r+2}-(-1)^ra_1 =\sum\limits_{i=1}^r(-1)^i\binom{r+1}{i}a_{r+2-i}+a_{r+2}+(-1)^{r+1}a_1 =\sum\limits_{i=0}^{r+1}(-1)^i\binom{r+1}{i}a_{r+2-i}

k=r+1时成立。

归纳得d_k=\sum\limits_{i=0}^k(-1)^i\binom{k}{i}a_{k+1-i}

d_k的式子拆开。

d_k=\sum\limits_{i=0}^k(-1)^i\binom{k}{i}a_{k+1-i}=k!\sum\limits_{i=0}^k\frac{(-1)^i}{i!}\cdot\frac{a_{k+1-i}}{(k-i)!}

多项式多点求值+\mathrm{NTT}即可。

带入答案式子,Ans=S(x)+a_0b_0=\frac{b_1}{1-q}\sum\limits_{k=0}^m d_k(\frac{q}{1-q})^k+a_0

算法2:

f_k=\sum\limits_{n=0}^{\infty}n^kr^n

那么 r\cdot f_k=\sum\limits_{n=0}^{\infty}n^kr^{n+1}=\sum\limits_{n=1}^{\infty}(n-1)^kr^n

(1-r)f_k=\sum\limits_{n=1}^{\infty}\big(n^k-(n-1)^k\big)r^n ~~~~~~~~~~~~~\,=r\sum\limits_{n=0}^{\infty}\big((n+1)^k-n^k\big)r^n ~~~~~~~~~~~~~\,=r\sum\limits_{n=0}^{\infty}\sum\limits_{i=0}^{k-1}\binom{k}{i}n^ir^n ~~~~~~~~~~~~~\,=r\sum\limits_{i=0}^{k-1}\binom{k}{i}\sum\limits_{n=0}^{\infty}n^ir^n ~~~~~~~~~~~~~\,=r\sum\limits_{i=0}^{k-1}\binom{k}{i}f_i

故有f_k=\frac{r}{1-r}\sum\limits_{i=0}^{k-1}\binom{k}{i}f_i

化一下式子,\frac{f_k}{k!}=\sum\limits_{i=0}^{k-1}\frac{f_i}{i!}\big(\frac{r}{1-r}\cdot \frac{1}{(k-i)!}\big),很明显分治fft可以解决这个问题。

最后很明显Ans=\sum\limits_{i=0}^ma_if_i

代码:

算法1:

#pragma GCC optimize ("Ofast")
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int mod=998244353;
const int N=1048580;
inline void rad(int &_){
    static char ch;
    while(ch=getchar(),ch<'0'||ch>'9');_=ch-48;
    while(ch=getchar(),ch<='9'&&ch>='0')_=_*10+ch-48;
}
inline void swap(int &u,int &v){int o=u;u=v;v=o;}
inline int __(int u){return u<mod?u:u-mod;}
inline int ___(int u){return u<0?u+mod:u;}
int ksm(int u,int v){
    int res=1;
    for(;v;v>>=1,u=1ll*u*u%mod)
    if(v&1)res=1ll*res*u%mod;
    return res;
}
int f[N],g[N],rnk[N],c[N],d[N],e[N];
int C[N],D[N],cr[N],dr[N],siz[N];
int h[N],fac[N],inv[N],cnt,n,m,r,s,Ans;
vector<int>vp[N];
void Ntt(int *t,int opt,int len){
    int g=3,g_=ksm(g,mod-2);
    for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]);
    for(int i=1;i<len;i<<=1){
        int wn=ksm(~opt?g:g_,(mod-1)/(i<<1));
        for(int j=0,J=i<<1;j<len;j+=J){
            int w=1;
            for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){
                int r=1ll*w*t[i+k]%mod;
                t[i+k]=___(t[k]-r);
                t[k]=__(t[k]+r);
            }
        }
    }
    if(~opt)return;
    int ny=ksm(len,mod-2);
    for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod;
}
void Inv(int *a,int Len,int *b){
    if(Len==1){b[0]=ksm(a[0],mod-2);return;}
    Inv(a,(Len+1)>>1,b);
    int len=1,_2=-1;
    while(len<Len+Len)len<<=1,_2++;
    for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
    memcpy(e,a,Len<<2);
    memset(e+Len,0,(len-Len)<<2);
    Ntt(e,1,len);Ntt(b,1,len);
    for(int i=0;i<len;i++)b[i]=1ll*(2-1ll*e[i]*b[i]%mod+mod)*b[i]%mod;
    Ntt(b,-1,len);
    memset(b+Len,0,(len-Len)<<2);
}
void ntt(int *a,int *b,int len1,int len2,int *t){
    int len=1,_2=-1;
    while(len<len1+len2)len<<=1,_2++;
    for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
    memcpy(C,a,len1<<2);
    memset(C+len1,0,(len-len1)<<2);
    memcpy(D,b,len2<<2);
    memset(D+len2,0,(len-len2)<<2); 
    Ntt(C,1,len);Ntt(D,1,len);
    for(int i=0;i<len;i++)t[i]=1ll*C[i]*D[i]%mod;
    Ntt(t,-1,len);
    if(&b[0]==&dr[0])memset(dr,0,len<<2);
}
void Mod(const int *a,const vector<int>&b,int *t,int lena,int lenb){
    for(int i=0;i<=lena;i++)c[i]=a[lena-i];
    for(int i=0;i<=lenb;i++)d[i]=b[lenb-i]; 
    int len=1;while(len<=lena+lenb)len<<=1;
    memset(c+lena-lenb+1,0,(len-lena+lenb-1)<<2);
    memset(d+lena-lenb+1,0,(len-lena+lenb-1)<<2);
    Inv(d,lena-lenb+1,dr);
    ntt(c,dr,lena-lenb+1,lena-lenb+1,cr);
    reverse(cr,cr+(lena-lenb+1));
    for(int i=0;i<=lenb;i++)d[i]=b[i];
    memset(d+lenb+1,0,(len-lenb-1)<<2);
    ntt(d,cr,lenb+1,lena-lenb+1,c);
    for(int i=0;i<lenb;i++)t[i]=___(a[i]-c[i]);
}
void Solve(int now,int ls,int rs,int *a){
    if(ls==rs){h[++cnt]=a[0];return;}
    int noww=now<<1,nrs=ls+rs>>1,b[siz[now]+1];
    Mod(a,vp[noww],b,siz[now]-1,siz[noww]);
    Solve(noww,ls,nrs,b);
    Mod(a,vp[noww|1],b,siz[now]-1,siz[noww|1]);
    Solve(noww|1,nrs+1,rs,b);
}
void Solve(int now,int ls,int rs){
    siz[now]=rs-ls+1;
    if(ls==rs){
        vp[now].resize(2);
        vp[now][0]=mod-g[ls];
        vp[now][1]=1;
        return;
    }
    int noww=now<<1,nrs=ls+rs>>1;
    Solve(noww,ls,nrs);Solve(noww|1,nrs+1,rs);
    int len=1,_2=-1;
    while(len<=siz[now])len<<=1,_2++;
    for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
    for(int i=0;i<=siz[noww];i++)c[i]=vp[noww][i];
    for(int i=0;i<=siz[noww|1];i++)d[i]=vp[noww|1][i];
    memset(c+siz[noww]+1,0,(len-siz[noww]-1)<<2);
    memset(d+siz[noww|1]+1,0,(len-siz[noww|1]-1)<<2);
    Ntt(c,1,len);Ntt(d,1,len);
    for(int i=0;i<len;i++)c[i]=1ll*c[i]*d[i]%mod;
    Ntt(c,-1,len);
    vp[now].resize(siz[now]+1);
    for(int i=0;i<=siz[now];i++)vp[now][i]=c[i];
}
int main(){
    rad(n);m=n+1;rad(r);
    for(int i=0;i<=n;i++)rad(f[i]);
    for(int i=1;i<=m;i++)g[i]=i;
    s=f[0];
    Solve(1,1,m);
    Solve(1,1,m,f);
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
    inv[n]=ksm(fac[n],mod-2);
    for(int i=n-1;i>=2;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
    for(int i=0;i<=n;i++)f[i]=___((i&1?-1:1)*inv[i]);
    for(int i=0;i<=n;i++)g[i]=1ll*h[i+1]*inv[i]%mod;
    ntt(f,g,n+1,n+1,h);
    r=1ll*r*ksm(1-r+mod,mod-2)%mod;
    for(int i=0,j=1;i<=n;i++,j=1ll*j*r%mod)
    h[i]=1ll*h[i]*fac[i]%mod*j%mod;
    for(int i=0;i<=n;i++)Ans=__(Ans+h[i]);
    Ans=__(1ll*Ans*r%mod+s);
    printf("%d\n",Ans);
}

算法2:

#include<cstdio>
const int N=262150;
const int mod=998244353;
int n,Ans,r,r_,A[N],fac[N],inv[N];
int f[N],g[N],c[N],rnk[N];
int ksm(int u,int v){
    int res=1;
    for(;v;v>>=1,u=1ll*u*u%mod)
    if(v&1)res=1ll*res*u%mod;
    return res;
}
inline void swap(int &u,int &v){int o=u;u=v;v=o;}
inline int _(int u){return u<mod?u:u-mod;}
inline int __(int u){return u<0?u+mod:u;}
void ntt(int *t,int opt,int len){
    int g=3,g_=ksm(g,mod-2);
    for(int i=0;i<len;i++)if(i<rnk[i])swap(t[i],t[rnk[i]]);
    for(int i=1;i<len;i<<=1){
        int wn=ksm(~opt?g:g_,(mod-1)/(i<<1));
        for(int j=0,J=i<<1;j<len;j+=J){
            int w=1;
            for(int k=j;k<i+j;k++,w=1ll*w*wn%mod){
                int r=1ll*w*t[i+k]%mod;
                t[i+k]=__(t[k]-r);
                t[k]=_(t[k]+r);
            }
        }
    }
    if(~opt)return;
    int ny=ksm(len,mod-2);
    for(int i=0;i<len;i++)t[i]=1ll*t[i]*ny%mod;
}
void Inv(int Len,int *a,int *b){
    if(Len==1){b[0]=ksm(a[0],mod-2);return;}
    Inv((Len+1)>>1,a,b);
    int len=1,_2=-1;
    while(len<Len+Len)len<<=1,_2++;
    for(int i=0;i<len;i++)rnk[i]=(rnk[i>>1]>>1)|((i&1)<<_2);
    for(int i=0;i<Len;i++)c[i]=a[i];
    for(int i=Len;i<len;i++)c[i]=0;
    ntt(c,1,len);ntt(b,1,len);
    for(int i=0;i<len;i++)
    b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)*b[i]%mod;
    ntt(b,-1,len);
    for(int i=Len;i<len;i++)b[i]=0;
}
int main(){
    scanf("%d%d",&n,&r);r_=1ll*r*ksm(1-r+mod,mod-2)%mod;
    for(int i=0;i<=n;i++)scanf("%d",&A[i]);
    fac[0]=inv[0]=1;
    for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
    inv[n]=ksm(fac[n],mod-2)%mod;
    for(int i=n-1;i>=1;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
    for(int i=1;i<=n;i++)g[i]=mod-1ll*inv[i]*r_%mod;
    g[0]=1;
    Inv(n+1,g,f);
    for(int i=0;i<=n;i++)f[i]=1ll*f[i]*fac[i]%mod*ksm(1-r+mod,mod-2)%mod;
    for(int i=0;i<=n;i++)Ans=_(Ans+1ll*f[i]*A[i]%mod);
    printf("%d\n",Ans);
}