[MtOI2019] T6 Solution

· · 题解

upd:之前式子有一点锅,现已修复。

这其实是一个很水的套路题。。

首先容易看出来 f(x,0) 是个线性递推的形式,要求的是其 k 阶前缀积。
要求乘积不太好搞,可以对 2 取一下对数,化乘为加。

于是问题转化为:

一个数列 a

a_n=n\space(n\le42) a_n=\sum\limits_{i=1}^{42}ia_{n-i}\space(n\ge 43)

求它 k 阶前缀和的第 n 项。

关于线性递推式的高阶前缀和有一个优美的性质。

设数列 a 的递推系数为 f,那么在 f 前面加个 -1 ,然后做 k 阶差分得到的序列即 ak 阶前缀和的递推式。( 当然要在后面扩展 k 项,同时最后去掉 -1 )

在此简短证明一下,设:

a_n=\sum\limits_{i=1}^kf_ia_{n-i} b_n=\sum\limits_{i=1}^na_i b_n=b_{n-1}+a_n=b_{n-1}+\sum\limits_{i=1}^kf_ia_{n-i} = b_{n-1}+\sum\limits_{i=1}^kf_i(b_{n-i}-b_{n-i-1})

后面的那个求和展开,可以得到很多形如 (f_i-f_{i-1})b_{n-i} 的式子,这就是一个很明显的差分形式,接下来的证明就很容易了。

得到 k 阶前缀和的递推式后,把 ak 阶前缀和的前几项也求出来,然后直接上线性递推板子即可。

不过要注意的是我们刚才取了个 \log,所以以上运算都要对 \color{red} 998244352 取模,这需要用到任意模数。7 次 FFT 的做法常数过大,不能通过;需要使用 4 次 FFT 的做法。

还有就是模数不是素数时,算组合数很麻烦,所以直接用倍增多项式快速幂计算高阶差分或前缀和即可。

时间复杂度 \Theta(k\log^2 k+k\log k\log n)

std:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 65539
#define ll long long
#define reg register
#define p 998244352
#define pi 3.141592653589793
using namespace std;

struct complex{
    double x,y;
    inline complex(double x=0,double y=0):x(x),y(y){}

    inline complex operator + (const complex& b) const{ return complex(x+b.x,y+b.y); }
    inline complex operator - (const complex& b) const{ return complex(x-b.x,y-b.y); }
    inline complex operator * (const complex& b) const{ return complex(x*b.x-y*b.y,x*b.y+y*b.x); }
    inline complex operator / (const int& b) const{ return complex(x/b,y/b); }
    inline complex operator ~ () const{ return complex(x,-y); }
}rt[N];

struct matrix{
    int a[43][43];
    int siz;
    inline matrix(int _siz=0):siz(_siz){ memset(a,0,sizeof(a)); }

    inline matrix operator * (const matrix& b) const{
        matrix res = matrix(siz);
        for(reg int i=0;i!=siz;++i)
        for(reg int j=0;j!=siz;++j)
        for(reg int k=0;k!=siz;++k)
            res.a[i][j] = (res.a[i][j]+(ll)a[i][k]*b.a[k][j])%p;
        return res;    
    }
};

inline matrix mat_pow(matrix a,ll t){
    matrix res = matrix(a.siz);
    for(reg int i=0;i!=a.siz;++i) res.a[i][i] = 1;
    while(1){
        if(t&1) res = res*a;
        t >>= 1;
        if(t==0) break;
        a = a*a;
    }
    return res;
}

int rev[N];
int siz;

inline int add(int a,int b){ return a+b>=p?a+b-p:a+b; }
inline int dec(int a,int b){ return a<b?a-b+p:a-b; }
inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); }

inline void init(int n){
    int lim = 1;
    while(lim<=n) lim <<= 1,++siz;
    for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    rt[lim>>1] = complex(1,0);
    for(reg int i=1;i!=(lim>>1);++i) rt[i+(lim>>1)] = complex(cos(pi*2/lim*i),sin(pi*2/lim*i));
    for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
}

inline void dft(complex *f,int lim){
    static complex a[N];
    int shift = siz-__builtin_ctz(lim);
    for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
    for(reg int mid=1;mid!=lim;mid<<=1)
    for(reg int j=0;j!=lim;j+=(mid<<1))
    for(reg int k=0;k!=mid;++k){
        complex x = a[j|k|mid]*rt[mid|k];
        a[j|k|mid] = a[j|k]-x;
        a[j|k] = a[j|k]+x;
    }
    for(reg int i=0;i!=lim;++i) f[i] = a[i];
}

inline void idft(complex *f,int lim){
    reverse(f+1,f+lim);
    dft(f,lim);
    for(reg int i=0;i!=lim;++i) f[i] = f[i]/lim;
}

inline void multiply(const int *A,const int *B,int n,int m,int *R,int len,bool flag){
    static complex f[N],g[N],h[N],q[N];
    complex t,f0,f1,g0,g1;
    ll x,y,z;
    int lim = getlen(n+m);
    for(reg int i=0;i!=lim;++i){
        f[i] = complex(A[i]>>15,A[i]&32767);
        g[i] = complex(B[i]>>15,B[i]&32767);
    }
    dft(f,lim);
    if(flag) for(reg int i=0;i!=lim;++i) g[i] = f[i];
    else dft(g,lim);
    for(reg int i=0;i!=lim;++i){
        t = ~f[i?lim-i:0];
        f0 = (f[i]-t)*complex(0,-0.5),f1 = (f[i]+t)*0.5;
        t = ~g[i?lim-i:0];
        g0 = (g[i]-t)*complex(0,-0.5),g1 = (g[i]+t)*0.5;
        h[i] = f1*g1;
        q[i] = f1*g0 + f0*g1 + f0*g0*complex(0,1);
    }
    idft(h,lim),idft(q,lim);
    for(reg int i=0;i<=len;++i){
        x = (ll)(h[i].x+0.5)%p<<30;
        y = (ll)(q[i].x+0.5)<<15;
        z = q[i].y+0.5;
        R[i] = (x+y+z)%p;
    }
    memset(R+len+1,0,(lim-len)<<2);
}

inline void inverse(const int *f,int n,int *R){ 
    static int g[N],h[N],st[30];
    memset(g,0,getlen(n<<1)<<2);
    int top = 0,lim =1 ;
    while(n){
        st[++top] = n;
        n >>= 1;
    }
    g[0] = 1;
    while(top--){
        n = st[top+1];
        while(lim<=(n<<1)) lim <<= 1;
        memcpy(h,f,(n+1)<<2);
        memset(h+n+1,0,(lim-n)<<2);
        multiply(h,g,n,n>>1,h,n,false);
        multiply(h,g,n,n>>1,h,n,false);
        for(reg int i=(n>>1);i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]);
    }
    memcpy(R,g,(n+1)<<2);
}

inline void divide(const int *f,const int *ig,int n,int m,int *R){
    static int A[N],B[N];
    memcpy(A,f,(n+1)<<2),memcpy(B,ig,(m+1)<<2);
    reverse(A,A+n+1);
    int tt = n-m,lim = getlen((n-m)<<1);
    memset(A+tt+1,0,(lim-tt)<<2);
    for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0;
    multiply(A,B,tt,tt,R,n-m,false);
    reverse(R,R+tt+1);
}

inline void mod(const int *f,const int *g,const int *ig,int n,int m,int *R){
    if(n<m) return;
    static int A[N],B[N];
    memcpy(B,f,(n+1)<<2);
    int lim = getlen(n);
    divide(f,ig,n,m,R);
    memcpy(A,g,(m+1)<<2);
    memset(A+m+1,0,(lim-m+2)<<2);
    memset(R+n-m+1,0,(lim-n+m+1)<<2);
    multiply(A,R,m,n-m,R,m-1,false);
    for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]);
}

void mod_power(const int *G,int k,ll t,int *R){
    int f[N],g[N],ig[N];
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    memset(ig,0,sizeof(ig));
    memcpy(ig,G,(k+1)<<2);
    reverse(ig,ig+k+1);
    inverse(ig,k,ig);
    int n = 1,m = 0;
    f[1] = g[0] = 1;
    while(1){
        if(t&1){
            multiply(f,g,n,m,g,n+m,false);
            mod(g,G,ig,n+m,k,g);
            m = min(n+m,k-1);
        }
        t >>= 1;
        if(t==0) break;
        multiply(f,f,n,n,f,n<<1,true);
        mod(f,G,ig,n<<1,k,f);
        n = min(n<<1,k-1);
    }
    memcpy(R,g,k<<2);
}

void poly_pow(const int *f,int n,int k,int *R){
    static int g[N],h[N];
    memset(g,0,sizeof(g)),memset(h,0,sizeof(h));
    memcpy(g,f,(n+1)<<2);
    h[0] = 1;
    int m = 0;
    while(1){
        if(k&1){
            multiply(h,g,n,m,h,n+m,false);
            m += n;
        }
        k >>= 1;
        if(k==0) break;
        multiply(g,g,n,n,g,n<<1,true);
        n <<= 1;
    }
    memcpy(R,h,(m+1)<<2);
}

inline int poww(int a,int t,int m){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%m;
        a = (ll)a*a%m;
        t >>= 1;
    }
    return res;
}

int k,lim;
int a[N],f[N],c[N],F[N],G[N];
ll n;

int special(){
    if(k==1){
        for(int i=1;i<=42;++i) a[42] += a[42-i]*i;
        for(int i=1;i<=42;++i) a[i] = add(a[i],a[i-1]);
        for(int i=43;i;--i) f[i] = dec(f[i],f[i-1]);
    }
    int siz = 42+k;
    matrix A = matrix(siz);
    for(reg int i=0;i!=siz;++i) A.a[i][0] = f[i+1];
    for(reg int i=1;i!=siz;++i) A.a[i-1][i] = 1;
    A = mat_pow(A,n-1);
    int res = 0;
    for(reg int i=0;i!=siz;++i) res = (res+(ll)a[siz-1-i]*A.a[i][siz-1])%p;
    return poww(2,res,p+1);
}

int main(){
    int ans = 0;
    scanf("%lld%d",&n,&k);
    if(n==1){
        putchar('2');
        return 0;
    }
    f[0] = p-1;
    for(reg int i=1;i<=42;++i) a[i-1] = f[i] = i;
    if(k<=1){
        printf("%d",special());
        return 0;
    }
    init((k+42)<<1);
    c[0] = 1,c[1] = p-1;
    poly_pow(c,1,k,c);
    lim = k+42; 
    for(reg int i=0;i<=42;++i)
    for(reg int j=0;j<=k;++j)
        G[i+j] = (G[i+j]+(ll)f[i]*c[j])%p;
    inverse(c,lim-1,c);
    for(reg int i=42;i!=lim;++i)
    for(reg int j=1;j<=42;++j)
        a[i] = (a[i]+(ll)a[i-j]*j)%p; 
    multiply(a,c,lim-1,lim-1,a,lim-1,false); 
    reverse(G,G+lim+1);
    for(reg int i=0;i<=lim;++i) G[i] = G[i]?p-G[i]:0;
    mod_power(G,lim,n-1,F);
    for(reg int i=0;i!=lim;++i) ans = (ans+(ll)F[i]*a[i])%p;   
    printf("%d",poww(2,ans,p+1));
    return 0;
}

ps:其实可以做到 \Theta(k\log k + \log n),需要带个几百倍的常数,这里就不写了(逃