题解 P5808 【常系数非齐次线性递推】

· · 题解

首先,题目中是非齐次递推,考虑将其化为齐次的。
以下若无说明,则 f_0=-1

a_n=\sum\limits_{i=0}^mp_in^i+\sum\limits_{i=1}^kf_i \times a_{n-i}

a_n=p_mn^m+\sum_{i=0}^{m-1}p_in^i+\sum_{i=1}^kf_i \times a_{n-i}

p_mn^m=a_n-\sum_{i=0}^{m-1}p_in^i-\sum_{i=1}^kf_i \times a_{n-i}

p_m(n-1)^m=a_{n-1}-\sum_{i=0}^{m-1}p_i(n-1)^i-\sum_{i=1}^kf_i \times a_{n-1-i}

式子推到这里其实就可以收手了;把左式的 (n-1)^m 展开后,除了 p_mn^m 项都移到右边去。然后用右边的式子,取代原式中的 p_mn^m

代回原式上面后,相当于把递推系数做了个差分(当然还要在后面加一项),而和多项式的系数没有关系。

也就是说将递推系数求 m+1 阶差分,就可以得到等价的齐次递推式了。

由于多了 m+1 项,就要再求出 a_k\sim a_{k+m}
原式的递推可以写成这样:
(其中 b 是我们构造的一个序列)

a_n=b_n+\sum_{i=1}^kf_i\times a_{n-i}

为了下面推式子方便,在这一部分中暂时改为定义 f_0=0(覆盖前文 f_0=-1 的约定),并规定 i>k 时 f_i=0,这样求和上界就可以写成 n:

a_n=b_n+\sum\limits_{i=0}^nf_i \times a_{n-i}

这样就是一个明显的卷积,设 \{a_n\}_{n=0}^\infty,\{b_n\}_{n=0}^\infty,\{f_n\}_{n=0}^\infty 的生成函数为 A(x),B(x),F(x),那么就有:

A(x)=B(x)+F(x)A(x)

(1-F(x))A(x)=B(x)

A(x)=\frac{B(x)}{1-F(x)}

所以只要构造出序列 b,就能多项式求逆算出 a 了。

首先对于 \forall n \ge k,有 b_n=P(n),其中 P(n)=\sum_{i=0}^mp_in^i 是题目给出的多项式;我们只需要 b_k\sim b_{k+m} 这 m+1 个值,对这些点做多项式多点求值即可。
而对于 n \le k-1

b_n=a_n-\sum\limits_{i=0}^nf_i \times a_{n-i}

这右边还是个卷积,由于这里 n 只能取到 k-1,所以就可以直接算出来了。

前面做的都是准备工作,搞完了之后直接上线性递推板子即可。

参考代码:
(为了可读性没有刻意卡常数)

#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 131077
#define p 998244353
#define ll long long
#define reg register
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;

// Reads one non-negative decimal integer from stdin into x.
// Any leading non-digit characters are skipped. If the input ends before a
// digit is found, x is left at 0 instead of looping forever on EOF.
inline void read(int &x){
    x = 0;
    int c = getchar(); // int rather than char so EOF stays distinguishable
    while(c!=EOF&&(c<'0'||c>'9')) c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48); // x*10 + digit value
        c = getchar();
    }
}

// Writes x to stdout in decimal, most significant digit first.
// No sign handling is performed; the program only prints residues in [0,p).
void print(int x){
    char digits[12];
    int cnt = 0;
    // Collect digits from least to most significant.
    digits[cnt++] = (char)(x%10+'0');
    while(x>9){
        x /= 10;
        digits[cnt++] = (char)(x%10+'0');
    }
    // Emit them in reverse order.
    while(cnt) putchar(digits[--cnt]);
}

// Returns a^t mod p by binary exponentiation (t is expected to be >= 0).
inline int power(int a,int t){
    int acc = 1;
    for(;t;t>>=1){
        if(t&1) acc = (ll)acc*a%p; // fold in the current bit
        a = (ll)a*a%p;             // square the base for the next bit
    }
    return acc;
}

// Tables filled by init(), all modulo p:
//   rt   - powers of the root of unity, indexed as rt[mid|k] by NTT
//   rev  - bit-reversal permutation on siz bits
//   inv  - modular inverses of 1..lim
//   fac / ifac - factorials and inverse factorials
int rt[N],rev[N],inv[N],fac[N],ifac[N];
// Number of bits of the maximal transform length chosen by init().
int siz;

inline int binom(int x,int y){
    if(x<y) return 0;
    return (ll)fac[x]*ifac[y]%p*ifac[x-y]%p;
}

// Precomputes every table used by the polynomial routines for transform
// lengths up to lim, the smallest power of two strictly greater than n.
// The global siz is incremented here, so this is meant to be called once.
void init(int n){
    int w,lim = 1;
    while(lim<=n) lim <<= 1,++siz; // afterwards lim == 1<<siz
    // Bit-reversal permutation on siz bits (rev[0] stays 0).
    for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // w = 3^((p-1)/lim): step between consecutive roots in the top layer.
    w = power(3,(p-1)>>siz);
    fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1;
    // Top layer: rt[lim/2 + j] = w^j.
    for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    // Lower layers take every second entry of the layer above.
    for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    // Linear-time modular inverses, inclusive of lim (NTT reads inv[lim]).
    for(reg int i=2;i<=lim;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
    // Factorials; ifac temporarily holds the same values and is fixed up below.
    for(reg int i=2;i<=lim;++i) ifac[i] = fac[i] = (ll)fac[i-1]*i%p;
    // Invert the largest factorial once, then walk downwards.
    ifac[lim] = power(fac[lim],p-2);
    for(reg int i=lim-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
}

// Returns the smallest power of two strictly greater than n
// (1 -> 2, 3 -> 4, 4 -> 8). For n <= 0 the result is 1:
// __builtin_clz(0) is undefined, and divide() asks for getlen(0) when the
// quotient has degree 0.
static inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}

// In-place number-theoretic transform of f[0..lim) modulo p.
//   type ==  1 : forward transform
//   type == -1 : inverse transform, done by reversing f[1..lim), running the
//                forward transform and scaling by inv[lim]
// lim must be a power of two not exceeding the length prepared by init().
inline void NTT(int *f,int type,int lim){
    if(type==-1) reverse(f+1,f+lim);
    // rev[] is built for siz bits; drop the extra low bits for shorter lengths.
    reg int x,shift = siz-__builtin_ctz(lim);
    static int a[N]; // scratch buffer holding the permuted input
    for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
    for(reg int mid=1;mid!=lim;mid<<=1){
        for(reg int j=0;j!=lim;j+=(mid<<1)){
            for(reg int k=0;k!=mid;++k){
                // Butterfly on the pair (j|k, j|k|mid) with twiddle rt[mid|k].
                x = (ll)a[j|k|mid]*rt[mid|k]%p;
                a[j|k|mid] = dec(a[j|k],x);
                a[j|k] = add(a[j|k],x);
            }
        }
    }
    memcpy(f,a,lim<<2); // lim ints, assuming 4-byte int
    if(type==1) return;
    x = inv[lim];
    for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}

// Power-series inverse: writes into R[0..n] the coefficients of 1/f modulo
// x^(n+1). f is only copied into scratch and R is written once at the very
// end, so callers pass R == f for an in-place inverse.
void inverse(const int *f,int *R,int n){
    static int g[N],h[N],q[N]; // g: current inverse, h: copy of f, q: previous g
    memset(g,0,getlen(n<<1)+2<<2); // parses as (getlen(n<<1)+2)<<2 bytes
    int lim = 1,top = 0;
    int s[30];
    // Record the precision schedule n, n/2, n/4, ..., 1.
    while(n){
        s[++top] = n;
        n >>= 1;
    }
    g[0] = power(f[0],p-2); // inverse of the constant term
    // Replay the schedule from the smallest precision upwards.
    while(top--){
        n = s[top+1];
        while(lim<=(n<<1)) lim <<= 1;
        memcpy(q,g,(n+1)<<2);
        memcpy(h,f,(n+1)<<2);
        memset(h+n+1,0,(lim-n)<<2);
        NTT(g,1,lim),NTT(h,1,lim);
        for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p*h[i]%p;
        NTT(g,-1,lim);
        // Newton step: g <- 2*g_old - g_old^2 * f, truncated to n+1 terms.
        for(reg int i=0;i<=n;++i) g[i] = dec(add(q[i],q[i]),g[i]);
        memset(g+n+1,0,(lim-n)<<2);
    }
    memcpy(R,g,(n+1)<<2);
}

// Polynomial quotient: f has degree n, g has degree m (n >= m expected);
// writes the n-m+1 quotient coefficients to R[0..n-m].
// Works on reversed coefficient arrays so the quotient becomes a truncated
// product with the power-series inverse of reversed g.
void divide(const int *f,const int *g,int n,int m,int *R){
    static int A[N],B[N];
    memcpy(A,f,(n+1)<<2);
    memcpy(B,g,(m+1)<<2);
    reverse(A,A+n+1);
    reverse(B,B+m+1);
    int tt = n-m,lim = getlen((n-m)<<1); // tt = quotient degree
    // Only the first tt+1 terms of either reversed operand matter.
    for(reg int i=tt+1;i!=lim;++i) A[i] = 0;
    for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0;
    inverse(B,B,tt);
    NTT(A,1,lim),NTT(B,1,lim);
    for(reg int i=0;i!=lim;++i) A[i] = (ll)A[i]*B[i]%p;
    NTT(A,-1,lim);
    reverse(A,A+tt+1); // undo the coefficient reversal
    memcpy(R,A,(tt+1)<<2);
}

// Polynomial remainder: R = f mod g, where f has degree n and g has degree m.
// On the n >= m path R[0..m-1] receives the remainder and R[m..lim) is
// zeroed (lim = getlen(n)), so R must have room for lim ints.
// On the n < m path f is already reduced and is copied unchanged.
// R may alias f.
void mod(const int *f,const int *g,int n,int m,int *R){
    if(n<m){
        // Callers such as mod_power pass R == f. memcpy on overlapping
        // storage is undefined, so use memmove, which handles that case.
        memmove(R,f,(n+1)<<2);
        return;
    }
    static int A[N],B[N];
    memcpy(B,f,(n+1)<<2); // keep f: R may alias it and is overwritten below
    int lim = getlen(n);
    divide(f,g,n,m,R);    // R[0..n-m] = quotient
    for(int i=0;i<=m;++i) A[i] = g[i];
    for(int i=m+1;i!=lim;++i) A[i] = 0;
    for(int i=n-m+1;i!=lim;++i) R[i] = 0;
    // R <- g * quotient
    NTT(A,1,lim),NTT(R,1,lim);
    for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*R[i]%p;
    NTT(R,-1,lim);
    // remainder = f - g*quotient; only the low m coefficients are non-zero
    for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]);
    for(int i=m;i!=lim;++i) R[i] = 0;
}

// Segment-tree style helpers for prepare() and solve(): midpoint of [l,r]
// and the left / right child indices of node u.
#define mid ((l+r)>>1)
#define ls (u<<1)
#define rs (u<<1|1)
// Ranges with r-l <= bflim are handled by direct loops instead of NTT.
int bflim;
// P[u]: heap-allocated coefficients of node u's polynomial;
// len[u]: number of points covered by node u (also the degree of P[u]).
int *P[N],len[N];

// Builds the product tree used by multipoint evaluation (divide-and-conquer
// multiplication): node u covering [l,r] stores
//   P[u](x) = prod_{j=l..r} (x - a[j])  (mod p)
// with len[u] = r-l+1. Short ranges (r-l <= bflim) are multiplied directly,
// longer ones through NTT.
void prepare(int l,int r,int u,const int *a){
    if(l==r){
        len[u] = 1;
        P[u] = new int[2];
        P[u][0] = p-a[l],P[u][1] = 1; // x - a[l]
        return;
    }
    prepare(l,mid,ls,a);
    prepare(mid+1,r,rs,a);
    len[u] = r-l+1;
    int lim = getlen(len[u]);
    P[u] = new int[len[u]+1];
    int F[lim+1],G[lim+1]; // valid indices are 0..lim
    memcpy(F,P[ls],(len[ls]+1)<<2);
    memcpy(G,P[rs],(len[rs]+1)<<2);
    if(r-l>bflim){
        // Zero-pad indices len+1 .. lim. The count is lim-len: one more
        // element would write index lim+1, past the end of the buffers.
        memset(F+len[ls]+1,0,(lim-len[ls])<<2);
        memset(G+len[rs]+1,0,(lim-len[rs])<<2);
        NTT(F,1,lim),NTT(G,1,lim);
        for(reg int i=0;i!=lim;++i) F[i] = (ll)F[i]*G[i]%p;
        NTT(F,-1,lim);
        memcpy(P[u],F,(len[u]+1)<<2);
    }else{
        // Schoolbook product of the two children.
        memset(P[u],0,(len[u]+1)<<2);
        for(reg int i=0;i<=len[ls];++i)
        for(reg int j=0;j<=len[rs];++j)
            P[u][i+j] = (P[u][i+j]+(ll)F[i]*G[j])%p;
    }
}

// Multipoint evaluation on the product tree built by prepare().
// F is a polynomial of degree n; for every j in [l,r] the value F(a[j]) mod p
// is stored in R[j]. u is the tree node covering [l,r].
void solve(const int *F,int l,int r,const int *a,int u,int n,int *R){
    if(r-l<=bflim){ // small range: evaluate each point directly
        ll pw[17]; // pw[i] = x^i mod p
        int res,x;
        ll s1,s2,s3,s4;
        pw[0] = 1;
        for(reg int j=l;j<=r;++j){
            res = F[n],x = a[j];
            reg int i = 1;
            for(;i<=16;++i) pw[i] = pw[i-1]*x%p;
            i = n-1;
            // Horner's rule unrolled 16 coefficients at a time:
            // res <- res*x^16 + F[i]*x^15 + ... + F[i-14]*x + F[i-15].
            // Each partial sum holds four products below p^2 and fits in ll.
            while(i>=15){
                s1 = res*pw[16]+F[i]*pw[15]+F[i-1]*pw[14]+F[i-2]*pw[13];
                s2 = F[i-3]*pw[12]+F[i-4]*pw[11]+F[i-5]*pw[10]+F[i-6]*pw[9];
                s3 = F[i-7]*pw[8]+F[i-8]*pw[7]+F[i-9]*pw[6]+F[i-10]*pw[5];
                // The last term uses pw[1], not the int x: int*int would
                // overflow before being widened to ll.
                s4 = F[i-11]*pw[4]+F[i-12]*pw[3]+F[i-13]*pw[2]+F[i-14]*pw[1];
                res = ((F[i-15]+s1+s2)%p+s3+s4)%p;
                i -= 16;
            }
            // Remaining (n mod 16) low coefficients, plain Horner.
            i = (n&15)-1;
            for(;~i;--i) res = ((ll)res*x+F[i])%p;
            R[j] = res;
        }
        return;
    }
    // Reduce F modulo each child's product polynomial and recurse; the
    // remainder has degree below len[child].
    int G[getlen(n<<1)+1];
    memset(G,0,sizeof(G));
    mod(F,P[ls],n,len[ls],G);
    solve(G,l,mid,a,ls,len[ls]-1,R);
    memset(G,0,sizeof(G));
    mod(F,P[rs],n,len[rs],G);
    solve(G,mid+1,r,a,rs,len[rs]-1,R);
}

// Entry point for multipoint evaluation: F has degree n, the m points are
// a[1..m], and F(a[j]) mod p is written to R[j] for j = 1..m.
void evaluation(const int *F,int *a,int n,int m,int *R){
    const int root = 1;      // index of the tree root
    bflim = (int)log2(m);    // ranges this short are evaluated directly
    prepare(1,m,root,a);
    solve(F,1,m,a,root,n,R);
}

#undef ls
#undef rs
#undef mid

void multiply(const int *F,const int *G,int n,int m,int len,int *R){
    static int A[N],B[N];
    memcpy(A,F,(n+1)<<2);
    memcpy(B,G,(m+1)<<2);
    int lim = getlen(n+m);
    memset(A+n+1,0,(lim-n+2)<<2);
    memset(B+m+1,0,(lim-m+2)<<2);
    NTT(A,1,lim),NTT(B,1,lim);
    for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*B[i]%p;
    NTT(R,-1,lim);
    memset(R+len+1,0,(lim-len+2)<<2);
}

// Polynomial fast exponentiation modulo G(x): computes x^t mod G, where G has
// degree k, and writes the k low coefficients to R.
// f holds the current power of x (degree n), g the accumulated result
// (degree m). Both are N-int arrays on the stack.
void mod_power(const int *G,int k,int t,int *R){
    int f[N],g[N];
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    int n = 1,m = 0;
    f[1] = g[0] = 1; // f = x, g = 1
    while(t){
        if(t&1){
            // g <- g*f mod G
            multiply(f,g,n,m,n+m,g);
            mod(g,G,n+m,k,g);
            m = min(n+m,k-1);
        }
        // f <- f^2 mod G
        int lim = getlen(n<<1);
        NTT(f,1,lim);
        for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*f[i]%p;
        NTT(f,-1,lim);
        mod(f,G,n<<1,k,f);
        n = min(n<<1,k-1);
        t >>= 1;
    }
    memcpy(R,g,k<<2);
}

// n: index of the term to output, m: degree of the polynomial part,
// k: order of the recurrence, ans: result, T = m+k+1: order of the
// homogeneous recurrence built in main().
int n,m,k,ans,T;
// Working arrays for main(): a = sequence terms, f = recurrence coefficients,
// G = polynomial coefficients (later the modulus), B = the constructed
// sequence b, A = series product, F / d = scratch.
int F[N],G[N],B[N],A[N],a[N],f[N],d[N];

// Reads n, m, k, the k initial terms, the k recurrence coefficients and the
// m+1 polynomial coefficients, converts the recurrence to a homogeneous one
// of order T = m+k+1, and prints a_n mod p.
int main(){
    init(100000);
    read(n),read(m),read(k);
    for(reg int i=0;i!=k;++i) read(a[i]);
    for(reg int i=1;i<=k;++i) read(f[i]);
    for(reg int i=0;i<=m;++i) read(G[i]);
    // Evaluation points k, k+1, ..., k+m, stored 1-indexed.
    for(reg int i=0;i<=m;++i) d[i+1] = i+k;
    evaluation(G,d,m,m+1,B); // build the second half of B: polynomial values at k..k+m
    // Shift B[1..m+1] to B[k..k+m]; descending order keeps the move safe.
    for(reg int i=k+m;i>=k;--i) B[i] = B[i-k+1];
    for(reg int i=0;i!=k;++i) B[i] = 0;
    multiply(f,a,k,k,k-1,d); // one convolution gives the first half of B
    for(reg int i=0;i!=k;++i) B[i] = dec(a[i],d[i]);
    f[0] = p-1;
    // F = -f, so F[0] = 1 and F[i] = -f[i]: the denominator of the series.
    for(reg int i=0;i<=k;++i) F[i] = p-f[i];
    inverse(F,F,m+k);
    multiply(F,B,m+k,m+k,m+k,A); // A = B * inverse(F), i.e. the sequence terms
    for(reg int i=0;i<=m+k;++i) a[i] = A[i];
    memset(F,0,sizeof(F));
    for(reg int i=0;i<=m+1;++i) F[i] = (m+1-i)&1?p-binom(m+1,i):binom(m+1,i); // coefficients of the (m+1)-th order difference
    T = m+k+1;
    multiply(F,f,m+1,k,T,f); // turn the recurrence into a homogeneous one
    memset(G,0,sizeof(G));
    // Modulus polynomial: negated coefficients of f in reversed order.
    for(reg int i=0;i<=T;++i) G[T-i] = p-f[i];
    mod_power(G,T,n,F); // F = x^n mod G
    // a_n is the dot product of that remainder with the first T terms.
    for(reg int i=0;i!=T;++i) ans = (ans+(ll)F[i]*a[i])%p;
    print(ans);
    return 0;
}