题解 P5273 【【模板】多项式幂函数 (加强版)】

· · 题解

做过P5245 【模板】多项式快速幂的都知道,求 F(x)^k 时,若 F(0)=1,就可以直接用\ln+\exp来搞。
也就是这个式子:

F(x)^k= \text e^{k\ln F(x)}

可是现在不保证 F(0)=1 了,我们就得想想办法,把 F(0) 搞成1
只要把 F(x) 的每一项都除以 F(0) ,问题不就解决了么?

F(x)^k=\left(\frac{F(x)}{F(0)} \right)^kF(0)^k

仔细一看:不对,如果 F(0)=0 的话,这个方法就挂掉了。
这时,可以考虑将 F(x) 降次,然后再升次。
具体来讲,我们要找出最小的 t,满足 F(x) 的 t 次项系数不为 0

F(x)^k=\left(\frac{F(x)}{x^t} \right)^kx^{tk}

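然后直接模拟过程+套模板即可。
注意本题的 k 非常大,要按用途分别处理:乘在 \ln F(x) 上的 k 对 p=998244353 取模;最低非零项系数的 k 次幂中,指数对 p-1 取模(费马小定理);而判断 tk 是否不小于 n 时必须使用 k 的真实大小。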
时间复杂度 \Theta(n\log n),比倍增法不知高到哪里去了。

#pragma GCC optimize (2)
#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 262147
#define ll long long
#define reg register
#define p 998244353
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1; 
    }
    return res;
}

// Lookup tables filled by init():
//   rev[i] - bit-reversed index of i for the full table length (1<<siz),
//   rt[]   - powers of a root of unity, laid out so that rt[mid|k] serves butterflies of half-length mid,
//   inv[i] - modular inverse of i modulo p.
int rev[N],rt[N],inv[N];
// Bit count of the table length chosen by init(); NTT() uses it to rescale rev[] for shorter transforms.
int siz;

// Builds the global tables used by NTT() and log().
//   n - largest index that will be needed: transforms of any power-of-two length up to the
//       smallest power of two greater than n become available, and inv[1..n] is filled.
// Writes rev[], rt[], inv[] and siz. Safe to call more than once: siz is recomputed from zero
// instead of being incremented on top of the value left by a previous call.
void init(int n){
    siz = 0;
    int lim = 1;
    while(lim<=n){
        lim <<= 1;
        ++siz;
    }
    // rev[0] is never written and stays 0; every other entry derives from the entry for i>>1.
    for(int i=1;i<lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));

    const int half = lim>>1;
    inv[1] = rt[half] = 1;
    // w is a primitive lim-th root of unity (3 is a primitive root of p).
    const int w = power(3,(p-1)>>siz);
    // Top level: rt[half+k] = w^k.
    for(int i=half+1;i<lim;++i) rt[i] = (long long)rt[i-1]*w%p;
    // Lower levels reuse every second root of the level above. The `i>0` bound also keeps the
    // loop from running when lim == 1 (half-1 would be negative).
    for(int i=half-1;i>0;--i) rt[i] = rt[i<<1];
    // Linear-time inverses: inv[i] = -(p/i) * inv[p%i] (mod p).
    for(int i=2;i<=n;++i) inv[i] = (long long)(p-p/i)*inv[p%i]%p;
}

// Returns the smallest power of two strictly greater than n.
// n <= 0 yields 1; __builtin_clz is undefined for a zero argument, so that case is
// handled explicitly (it is reached when a polynomial has a single coefficient).
inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}

// In-place number-theoretic transform of f[0..lim) modulo p.
//   f    - coefficient / point-value array, overwritten with the result (values reduced mod p)
//   type - 1 for the forward transform, -1 for the inverse transform
//   lim  - transform length; must be a power of two not exceeding the table length set up by init()
// Uses the global tables rev[], rt[] and siz, so init() has to run first.
inline void NTT(int *f,int type,int lim){
    // Inverse transform = reverse f[1..lim), run the same forward butterflies, then rescale at the end.
    if(type==-1) reverse(f+1,f+lim);
    // 64-bit work buffer: values are only reduced where a product requires it (see below).
    static unsigned long long a[N];
    // rev[] is built for length 1<<siz; dropping the low `shift` bits gives the bit reversal for length lim.
    reg int x,shift = siz-__builtin_ctz(lim);
    for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
    for(reg int mid=1;mid!=lim;mid<<=1)
    for(reg int j=0;j!=lim;j+=(mid<<1))
    for(reg int k=0;k!=mid;++k){
        // Only the twiddle product is reduced; the sums are left unreduced ("lazy" reduction).
        // NOTE(review): each layer can grow an entry by up to p, and a[...]*rt[...] has to fit in
        // 64 bits. With p close to 2^30 this looks safe for lim up to 2^18 (the largest length
        // that fits in N) - re-check before enlarging N.
        x = a[j|k|mid]*rt[mid|k]%p;
        // "+p" keeps the difference non-negative in unsigned arithmetic.
        a[j|k|mid] = a[j|k]-x+p;
        a[j|k] += x;
    } 
    for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
    if(type==1) return;
    // siz-shift == log2(lim), so x = p-(p-1)/lim; lim*x = lim*p-(p-1) = 1 (mod p), i.e. x is 1/lim.
    x = p-((p-1)>>(siz-shift));
    for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}

// Power-series inverse modulo p by Newton iteration.
//   f - input coefficients f[0..n]; the iteration is seeded with the modular inverse of f[0]
//   n - highest coefficient index to compute
//   R - receives n+1 coefficients of the inverse (indices 0..n)
// The precision is refined through the lengths n>>levels-1, ..., n>>1, n, applying
// g <- g*(2 - g*f) at each step in the NTT domain.
inline void inverse(const int *f,int n,int *R){
    static int g[N],h[N];
    const int target = n;

    std::fill_n(g,getlen(target<<1),0);
    g[0] = power(f[0],p-2);

    // Number of refinement steps = bit length of the target index.
    int levels = 0;
    for(int t=target;t;t>>=1) ++levels;

    int lim = 1;
    for(int lv=levels-1;lv>=0;--lv){
        const int cur = target>>lv;
        while(lim<=(cur<<1)) lim <<= 1;

        // h = f truncated to cur+1 coefficients, zero-padded for the transform.
        std::copy(f,f+cur+1,h);
        std::fill_n(h+cur+1,lim-cur,0);

        NTT(g,1,lim);
        NTT(h,1,lim);
        for(int i=0;i<lim;++i){
            const long long gh = (long long)g[i]*h[i]%p;
            g[i] = (int)(g[i]*(2-gh+p)%p);
        }
        NTT(g,-1,lim);

        // Drop the coefficients above the current precision, except after the final step.
        if(lv>0) std::fill_n(g+cur+1,lim-cur,0);
    }
    std::copy(g,g+target+1,R);
}

inline void log(int *f,int n){
    static int g[N];
    int lim = getlen(n<<1);
    inverse(f,n,g);
    memset(g+n+1,0,(lim-n)<<2);
    for(reg int i=0;i!=n;++i) f[i] = (ll)f[i+1]*(i+1)%p;
    f[n] = 0;
    NTT(f,1,lim),NTT(g,1,lim);
    for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*g[i]%p;
    NTT(f,-1,lim);
    for(reg int i=n;i;--i) f[i] = (ll)f[i-1]*inv[i]%p;
    f[0] = 0;
    memset(f+n+1,0,(lim-n)<<2);
}

inline void exp(int *f,int n){
    static int g[N],h[N],s[30];
    int lim = 1,top = 0;
    memset(g,0,getlen(n<<1)<<2);
    while(n){
        s[++top] = n;
        n >>= 1;
    }
    g[0] = 1;
    while(top--){
        n = s[top+1];
        while(lim<=(n<<1)) lim <<= 1;
        memcpy(h,g,(n+1)<<2);
        memset(h+n+1,0,(lim-n)<<2);
        log(g,n);
        for(reg int i=0;i<=n;++i) g[i] = dec(f[i],g[i]);
        g[0] = add(g[0],1);
        NTT(g,1,lim),NTT(h,1,lim);
        for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
        NTT(g,-1,lim);
        memset(g+n+1,0,(lim-n)<<2);
    }
    memcpy(f,g,(n+1)<<2);
}

// Computes f(x)^k modulo x^(n+1) and prints the n+1 coefficients, each followed by a space.
//   f  - coefficients f[0..n], overwritten as scratch space; entries above index n must be
//        zero up to the transform length, since they are read and transformed as-is
//   n  - highest coefficient index
//   k1 - exponent reduced modulo p; it scales ln f and is also multiplied by the number of
//        leading zero coefficients, so when f[0] == 0 it has to equal the true exponent
//        (main() filters out the cases where it would not)
//   k2 - exponent reduced modulo p-1, used for the scalar power of the lowest non-zero coefficient
// Method: divide out x^shift and the lowest non-zero coefficient so that the constant term is 1,
// take exp(k1 * ln(.)), then restore the scalar factor and the shift while printing.
// The `register` specifier was dropped: it was removed in C++17 and is rejected by some compilers.
inline void power(int *f,int n,int k1,int k2){
    // shift = number of leading zero coefficients.
    int shift = 0;
    while(shift<=n&&f[shift]==0) ++shift;

    // Lowest non-zero term of the result is x^(shift*k1); beyond n everything printed is zero.
    if((long long)shift*k1>n){
        for(int i=0;i<=n;++i) printf("0 ");
        return;
    }

    const int u = power(f[shift],p-2);   // inverse of the lowest non-zero coefficient
    const int v = power(f[shift],k2);    // that coefficient raised to the exponent
    for(int i=0;i<=n;++i) f[i] = (long long)f[i+shift]*u%p;

    log(f,n);
    // f[0] is 0 after log(), so scaling starts at index 1.
    for(int i=1;i<=n;++i) f[i] = (long long)f[i]*k1%p;
    exp(f,n);

    shift *= k1;
    for(int i=0;i!=shift;++i) printf("0 ");
    for(int i=shift;i<=n;++i) printf("%lld ",(long long)f[i-shift]*v%p);
}

// Input polynomial coefficients; zero-initialised, so entries past the ones read stay 0.
int F[N];
// Decimal digits of the exponent, read as a string because it does not fit a machine integer.
char str[N];
// n  - number of coefficients read into F;
// k1 - exponent modulo p; k2 - exponent modulo p-1;
// k3 - helper value main() builds from the exponent's digits for its comparison against n.
int n,k1,k2,k3;

int main(){
    scanf("%d%s",&n,str);
    int ln = strlen(str);
    for(reg int i=0;i!=ln;++i){
        k1 = (10ll*k1+str[i]-'0')%p;
        k2 = (10ll*k2+str[i]-'0')%(p-1);
    }
    for(reg int i=0;i!=min(6,ln);++i) k3 = 10*k3+str[i]-'0';
    for(reg int i=0;i!=n;++i) scanf("%d",&F[i]);
    if(F[0]==0&&k3>=n){
        for(reg int i=0;i!=n;++i) printf("0 ");
        return 0;
    }
    init(n<<1|1);
    power(F,n-1,k1,k2);
    return 0;   
}