题解 P5808 【常系数非齐次线性递推】
NaCly_Fish · · 题解
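首先,题目中是非齐次递推:对 $n\ge k$,$a_n=\sum_{i=1}^{k}f_ia_{n-i}+g(n)$,其中 $g(x)=\sum_{i=0}^{m}g_ix^i$ 是一个 $m$ 次多项式。考虑将其化为齐次的。
以下若无说明,则 $n\ge k+1$。把 $n$ 和 $n-1$ 处的两个式子相减,得到 $a_n-a_{n-1}=\sum_{i=1}^{k}f_i\left(a_{n-i}-a_{n-1-i}\right)+\left(g(n)-g(n-1)\right)$,而 $g(n)-g(n-1)$ 只是一个 $m-1$ 次多项式。
式子推到这里其实就可以收手了;把左式的 $a_{n-1}$ 移到右边并代回原式上面后,相当于把递推系数做了个差分(当然还要在后面加一项),而和多项式的系数没有关系。
也就是说将递推系数求 $m+1$ 阶差分(即把 $1-\sum_{i=1}^{k}f_ix^i$ 乘上 $(1-x)^{m+1}$),多项式项就被完全消掉,得到一个 $m+k+1$ 阶的齐次线性递推。
由于多了 $m+1$ 阶,需要先求出前 $m+k+1$ 项 $a_0,a_1,\dots,a_{m+k}$ 作为初值。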
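原式的递推可以写成这样:$a_n=\sum_{i=1}^{n}f_ia_{n-i}+b_n$
(其中 $i>k$ 时 $f_i=0$;当 $n\ge k$ 时 $b_n=g(n)$,当 $n<k$ 时 $b_n$ 用来修正初值,定义见下文。)
为了下面推式子方便,定义 $f_0=0$,这样上式就对所有 $n\ge 0$ 成立:$a_n=\sum_{i=0}^{n}f_ia_{n-i}+b_n$。
这样就是一个明显的卷积,设 $A(x)=\sum_{i\ge 0}a_ix^i$,$F(x)=\sum_{i=1}^{k}f_ix^i$,$B(x)=\sum_{i\ge 0}b_ix^i$,则 $A(x)=A(x)F(x)+B(x)$,即 $A(x)=\dfrac{B(x)}{1-F(x)}$,只需算到 $x^{m+k}$ 项。
所以只要构造出序列 $b_0,b_1,\dots,b_{m+k}$,做一次多项式求逆和一次乘法就能得到所需的初值。
首先对于 $k\le i\le m+k$,$b_i=g(i)$,对 $g$ 在 $k,k+1,\dots,k+m$ 处多点求值即可。
而对于 $i<k$,$b_i=a_i-\sum_{j=1}^{i}f_ja_{i-j}$。
这右边还是个卷积,由于这里 $i<k$,只需要 $F(x)A(x)$ 的前 $k$ 项,用题目给出的 $a_0,\dots,a_{k-1}$ 直接卷积即可。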
前面做的都是准备工作,搞完了之后直接上线性递推板子即可。
参考代码:
(为了可读性没有刻意卡常数)
#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
// Length of every global / static scratch array. It must exceed the table length built by
// init() (131072 for init(100000)), because fac/ifac/inv are indexed up to that length.
#define N 131077
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; init() uses 3 as the generator.
#define p 998244353
#define ll long long
// `register` is deprecated in C++11 and ill-formed since C++17. Keep the token as a no-op so
// the existing `reg int i` declarations compile under every standard.
#define reg
// Modular add / subtract for operands already in [0, p]. The arguments are parenthesized so
// compound expressions expand correctly. Each argument may still be evaluated more than once.
#define add(x,y) ((x)+(y)>=p?(x)+(y)-p:(x)+(y))
#define dec(x,y) ((x)<(y)?(x)-(y)+p:(x)-(y))
using namespace std;
// Reads the next unsigned decimal integer from stdin into x.
// Any non-digit characters before it are skipped (including a '-' sign, so "-15" yields 15).
// If input ends before a digit is found, x is left as 0 instead of looping forever on EOF.
// Assumes the value fits in int.
inline void read(int &x){
    x = 0;
    // int, not char: EOF must be distinguishable from every character value.
    int c = getchar();
    while(c!=EOF&&(c<'0'||c>'9')) c = getchar();
    while(c>='0'&&c<='9'){
        x = x*10+(c-'0');
        c = getchar();
    }
}
// Writes x in decimal to stdout with no separator or newline (intended for non-negative x).
// Digits are collected least-significant first, then emitted in reverse order.
void print(int x){
    char digits[12];
    int cnt = 0;
    do{
        digits[cnt++] = (char)(x%10+'0');
        x /= 10;
    }while(x>0);
    while(cnt>0) putchar(digits[--cnt]);
}
inline int power(int a,int t){
int res = 1;
while(t){
if(t&1) res = (ll)res*a%p;
a = (ll)a*a%p;
t >>= 1;
}
return res;
}
// Tables filled by init():
//   rt       NTT twiddle factors
//   rev      bit-reversal permutation
//   inv      modular inverses 1..lim
//   fac/ifac factorials and inverse factorials
int rt[N],rev[N],inv[N],fac[N],ifac[N];
// log2 of the transform length prepared by init(). NTT() derives its bit-reversal shift from it.
int siz;
inline int binom(int x,int y){
if(x<y) return 0;
return (ll)fac[x]*ifac[y]%p*ifac[x-y]%p;
}
// Precomputes tables for transforms up to length lim, the smallest power of two greater than n:
//   - rev: bit-reversal table;
//   - rt: NTT twiddle table;
//   - inv: modular inverses inv[1..lim];
//   - fac / ifac: factorial and inverse-factorial tables used by binom().
// Sets siz = log2(lim).
// NOTE(review): siz is incremented in place, so this assumes a single call made while siz == 0.
void init(int n){
int w,lim = 1;
while(lim<=n) lim <<= 1,++siz;
// rev[i] = i with its low siz bits reversed (rev[0] stays 0).
for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
// w = 3^((p-1)/lim): a lim-th root of unity mod p, using 3 as the generator.
w = power(3,(p-1)>>siz);
fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1;
// First, rt[lim/2 + j] = w^j for 0 <= j < lim/2. The next loop then fills the lower levels
// with rt[i] = rt[2i]. Result: rt[mid + k] is the k-th power of a (2*mid)-th root of unity,
// which is the layout NTT() indexes with rt[mid|k].
for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
// Linear-time inverses: inv[i] = -(p / i) * inv[p mod i] (mod p).
for(reg int i=2;i<=lim;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
// ifac[i] temporarily holds fac[i]; the backward pass below overwrites it with 1/i!.
for(reg int i=2;i<=lim;++i) ifac[i] = fac[i] = (ll)fac[i-1]*i%p;
ifac[lim] = power(fac[lim],p-2);
for(reg int i=lim-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
}
// Returns the smallest power of two strictly greater than n. A transform of that length holds
// coefficients 0..n without wrap-around. getlen(n) == 1 for n <= 0.
// __builtin_clz(0) is undefined, and n == 0 does occur in this file: divide() calls
// getlen((n-m)<<1) with n == m, and inverse() calls getlen(n<<1) with n == 0.
inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz((unsigned)n));
}
// In-place iterative number-theoretic transform of f[0..lim-1].
// lim must be a power of two with lim <= 2^siz, since rev[] was built by init() for 2^siz.
// type == 1: forward transform.
// type == -1: inverse transform, done by reversing f[1..lim-1], running the same forward
// butterflies, then scaling by inv[lim] = 1/lim.
// Uses a static scratch buffer, so it is not reentrant.
inline void NTT(int *f,int type,int lim){
if(type==-1) reverse(f+1,f+lim);
// rev[i] reverses siz bits; shifting right by `shift` gives the reversal for log2(lim) bits.
reg int x,shift = siz-__builtin_ctz(lim);
static int a[N];
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
// Radix-2 butterflies on blocks of size 2*mid; rt[mid|k] is the twiddle for offset k.
for(reg int mid=1;mid!=lim;mid<<=1){
for(reg int j=0;j!=lim;j+=(mid<<1)){
for(reg int k=0;k!=mid;++k){
x = (ll)a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = dec(a[j|k],x);
a[j|k] = add(a[j|k],x);
}
}
}
memcpy(f,a,lim<<2);
if(type==1) return;
x = inv[lim];
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}
// Power-series inverse: writes R[0..n] = coefficients of f^{-1} mod x^{n+1}.
// Uses the Newton iteration g <- 2g - g^2 f, doubling precision up to n.
// Requires f[0] != 0 (mod p); only f[0..n] is read.
// R may alias f: f is copied into scratch buffers each round, and R is written only at the end.
// Uses static scratch buffers, so it is not reentrant.
// NOTE(review): for n == 0 the memset size below evaluates getlen(0).
void inverse(const int *f,int *R,int n){
static int g[N],h[N],q[N];
// Clear g up to the largest transform length used below (plus two entries of slack).
memset(g,0,getlen(n<<1)+2<<2);
int lim = 1,top = 0;
int s[30];
// s[1..top] = n, n/2, n/4, ..., 1: the precisions to reach, processed from smallest to largest.
while(n){
s[++top] = n;
n >>= 1;
}
g[0] = power(f[0],p-2);
while(top--){
n = s[top+1];
// lim > 2n, so g*g*h (degree <= 2n) does not wrap around.
while(lim<=(n<<1)) lim <<= 1;
// q: current approximation.
memcpy(q,g,(n+1)<<2);
// h: f truncated to degree n.
memcpy(h,f,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
NTT(g,1,lim),NTT(h,1,lim);
for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p*h[i]%p;
NTT(g,-1,lim);
// g = 2q - q^2 h, truncated to degree n.
for(reg int i=0;i<=n;++i) g[i] = dec(add(q[i],q[i]),g[i]);
memset(g+n+1,0,(lim-n)<<2);
}
memcpy(R,g,(n+1)<<2);
}
// Polynomial division: writes the quotient floor(f / g) to R[0..n-m].
// f has degree n, g has degree m, and n >= m.
// Uses the reversal identity: rev(quotient) = rev(f) * rev(g)^{-1} mod x^{n-m+1}.
// R may alias f or g: both are copied into the static buffers A and B before R is written.
// NOTE(review): g[m] must be nonzero mod p because rev(g) is inverted. When n == m, tt is 0
// and this evaluates getlen(0); confirm getlen is well defined for 0.
void divide(const int *f,const int *g,int n,int m,int *R){
static int A[N],B[N];
memcpy(A,f,(n+1)<<2);
memcpy(B,g,(m+1)<<2);
reverse(A,A+n+1);
reverse(B,B+m+1);
int tt = n-m,lim = getlen((n-m)<<1);
// Only the low tt+1 coefficients of rev(f) and rev(g) affect the quotient.
for(reg int i=tt+1;i!=lim;++i) A[i] = 0;
for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0;
inverse(B,B,tt);
NTT(A,1,lim),NTT(B,1,lim);
for(reg int i=0;i!=lim;++i) A[i] = (ll)A[i]*B[i]%p;
NTT(A,-1,lim);
reverse(A,A+tt+1);
memcpy(R,A,(tt+1)<<2);
}
// Polynomial remainder R = f mod g, for f of degree n and g of degree m.
// If n < m, f[0..n] is copied to R unchanged; entries of R above n are not touched.
// Otherwise:
//   - R[0..m-1] = low m coefficients of f - g * floor(f / g);
//   - R[m..lim-1] is cleared, with lim = getlen(n).
// R may alias f: f is saved into the static buffer B before R is written.
void mod(const int *f,const int *g,int n,int m,int *R){
if(n<m){
memcpy(R,f,(n+1)<<2);
return;
}
static int A[N],B[N];
memcpy(B,f,(n+1)<<2);
// lim > n, so g * quotient (degree n) does not wrap around.
int lim = getlen(n);
// R = quotient, degree n-m.
divide(f,g,n,m,R);
for(int i=0;i<=m;++i) A[i] = g[i];
for(int i=m+1;i!=lim;++i) A[i] = 0;
for(int i=n-m+1;i!=lim;++i) R[i] = 0;
NTT(A,1,lim),NTT(R,1,lim);
for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*R[i]%p;
NTT(R,-1,lim);
// Remainder = f - g * quotient; only the low m coefficients can be nonzero.
for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]);
for(int i=m;i!=lim;++i) R[i] = 0;
}
// Segment-tree helpers for the subproduct tree:
//   mid    midpoint of [l, r]
//   ls, rs left / right child of node u
#define mid ((l+r)>>1)
#define ls (u<<1)
#define rs (u<<1|1)
// Brute-force threshold: in prepare()/solve(), nodes with r - l <= bflim use
// schoolbook multiplication / direct evaluation instead of NTT-based steps.
int bflim;
// P[u]: coefficient array of the product polynomial stored at tree node u.
// len[u]: its degree.
int *P[N],len[N];
// Builds the subproduct tree used by multipoint evaluation (divide-and-conquer product).
// Node u covers points a[l..r] and stores P[u] = prod_{i=l}^{r} (x - a[i]) mod p,
// which is monic of degree len[u] = r-l+1.
// Nodes spanning more than bflim+1 points multiply their children with NTT; smaller nodes use
// schoolbook multiplication.
// NOTE(review): the P[u] buffers are allocated with new[] and never released.
void prepare(int l,int r,int u,const int *a){
    if(l==r){
        len[u] = 1;
        P[u] = new int[2];
        P[u][0] = p-a[l],P[u][1] = 1;
        return;
    }
    prepare(l,mid,ls,a);
    prepare(mid+1,r,rs,a);
    len[u] = r-l+1;
    // Smallest power of two above len[u], so the product does not wrap around in the NTT.
    int lim = getlen(len[u]);
    P[u] = new int[len[u]+1];
    int F[lim+1],G[lim+1];
    memcpy(F,P[ls],(len[ls]+1)<<2);
    memcpy(G,P[rs],(len[rs]+1)<<2);
    if(r-l>bflim){
        // Zero-pad each factor over indices len+1 .. lim-1; the NTT reads only [0, lim).
        // Clearing lim-len+1 entries would also write F[lim+1] / G[lim+1], which is one
        // element past the end of these VLAs.
        memset(F+len[ls]+1,0,(lim-len[ls]-1)*sizeof(int));
        memset(G+len[rs]+1,0,(lim-len[rs]-1)*sizeof(int));
        NTT(F,1,lim),NTT(G,1,lim);
        for(reg int i=0;i!=lim;++i) F[i] = (ll)F[i]*G[i]%p;
        NTT(F,-1,lim);
        memcpy(P[u],F,(len[u]+1)<<2);
    }else{
        memset(P[u],0,(len[u]+1)<<2);
        for(reg int i=0;i<=len[ls];++i)
            for(reg int j=0;j<=len[rs];++j)
                P[u][i+j] = (P[u][i+j]+(ll)F[i]*G[j])%p;
    }
}
// Multipoint evaluation over the subproduct tree built by prepare().
// Writes F(a[j]) mod p into R[j] for every j in [l, r]; F has degree n.
// Small nodes (r - l <= bflim) are evaluated directly with Horner's rule, 16 coefficients
// per step. Larger nodes reduce F modulo each child's product polynomial and recurse. The
// remainder has the same values at that child's points, because the child polynomial is
// prod (x - a[i]) over those points.
void solve(const int *F,int l,int r,const int *a,int u,int n,int *R){
    if(r-l<=bflim){ // small range: brute force
        ll pw[17];
        int res,x;
        ll s1,s2,s3,s4;
        pw[0] = 1;
        for(reg int j=l;j<=r;++j){
            res = F[n],x = a[j];
            reg int i = 1;
            // pw[i] = x^i mod p
            for(;i<=16;++i) pw[i] = pw[i-1]*x%p;
            i = n-1;
            // Each step folds 16 coefficients: res = res*x^16 + F[i]*x^15 + ... + F[i-15].
            // Every term multiplies two values no larger than p, so the sums below stay
            // under 2^63. All terms use the long long powers pw[], including the x^1 term
            // (pw[1]); F[i-14]*x would be an overflowing int*int product.
            while(i>=15){
                s1 = res*pw[16]+F[i]*pw[15]+F[i-1]*pw[14]+F[i-2]*pw[13];
                s2 = F[i-3]*pw[12]+F[i-4]*pw[11]+F[i-5]*pw[10]+F[i-6]*pw[9];
                s3 = F[i-7]*pw[8]+F[i-8]*pw[7]+F[i-9]*pw[6]+F[i-10]*pw[5];
                s4 = F[i-11]*pw[4]+F[i-12]*pw[3]+F[i-13]*pw[2]+F[i-14]*pw[1];
                res = ((F[i-15]+s1+s2)%p+s3+s4)%p;
                i -= 16;
            }
            // Remaining (n & 15) low coefficients, one Horner step each.
            i = (n&15)-1;
            for(;~i;--i) res = ((ll)res*x+F[i])%p;
            R[j] = res;
        }
        return;
    }
    int G[getlen(n<<1)+1];
    memset(G,0,sizeof(G));
    mod(F,P[ls],n,len[ls],G);
    solve(G,l,mid,a,ls,len[ls]-1,R);
    memset(G,0,sizeof(G));
    mod(F,P[rs],n,len[rs],G);
    solve(G,mid+1,r,a,rs,len[rs]-1,R);
}
// Multipoint evaluation: R[j] = F(a[j]) mod p for j = 1..m, where F has degree n.
// Steps:
//   1. Set the brute-force threshold bflim to log2(m), truncated to int.
//   2. Build the subproduct tree over a[1..m], rooted at node 1.
//   3. Descend the tree to evaluate F at every point.
void evaluation(const int *F,int *a,int n,int m,int *R){
    const int root = 1;
    bflim = (int)log2(m);
    prepare(1,m,root,a);
    solve(F,1,m,a,root,n,R);
}
// Limit the single-letter tree macros (ls, rs, mid) to the multipoint-evaluation code, so they
// cannot rewrite identifiers in the code that follows.
#undef ls
#undef rs
#undef mid
void multiply(const int *F,const int *G,int n,int m,int len,int *R){
static int A[N],B[N];
memcpy(A,F,(n+1)<<2);
memcpy(B,G,(m+1)<<2);
int lim = getlen(n+m);
memset(A+n+1,0,(lim-n+2)<<2);
memset(B+m+1,0,(lim-m+2)<<2);
NTT(A,1,lim),NTT(B,1,lim);
for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*B[i]%p;
NTT(R,-1,lim);
memset(R+len+1,0,(lim-len+2)<<2);
}
// Polynomial fast power modulo G(x): R[0..k-1] = coefficients of x^t mod G, where G has
// degree k. Binary exponentiation:
//   f holds x^(2^j) mod G, with degree bound n;
//   g holds the running product, with degree bound m.
// The two N-int scratch polynomials are about 1 MB together. They are static so that they do
// not occupy the stack; as a consequence this function is not reentrant.
void mod_power(const int *G,int k,int t,int *R){
    static int f[N],g[N];
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    int n = 1,m = 0;
    // f = x, g = 1
    f[1] = g[0] = 1;
    while(t){
        if(t&1){
            multiply(f,g,n,m,n+m,g);
            mod(g,G,n+m,k,g);
            m = min(n+m,k-1);
        }
        // f = f^2 mod G
        int lim = getlen(n<<1);
        NTT(f,1,lim);
        for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*f[i]%p;
        NTT(f,-1,lim);
        mod(f,G,n<<1,k,f);
        n = min(n<<1,k-1);
        t >>= 1;
    }
    memcpy(R,g,k<<2);
}
// Problem input and result:
//   n   index of the term printed by main()
//   m   degree of the polynomial read into G
//   k   number of initial values / recurrence coefficients
//   ans the printed result
//   T   order of the homogenized recurrence (m + k + 1)
int n,m,k,ans,T;
// Global polynomial buffers reused across the steps of main().
int F[N],G[N],B[N],A[N],a[N],f[N],d[N];
// Reads n, m, k, the initial values a[0..k-1], the coefficients f[1..k] and the polynomial
// G[0..m]. Prints sum_{i<T} [x^i](x^n mod G') * a_i, where G' is the characteristic
// polynomial of the homogenized recurrence.
// NOTE(review): this presumably evaluates a_n for a_i = sum_{j=1}^{k} f_j a_{i-j} + G(i)
// (i >= k). That form is inferred from how B, F and A are built below; confirm against the
// problem statement.
int main(){
init(100000);
read(n),read(m),read(k);
for(reg int i=0;i!=k;++i) read(a[i]);
for(reg int i=1;i<=k;++i) read(f[i]);
for(reg int i=0;i<=m;++i) read(G[i]);
// Evaluation points k, k+1, ..., k+m, stored 1-based in d[1..m+1].
for(reg int i=0;i<=m;++i) d[i+1] = i+k;
evaluation(G,d,m,m+1,B); // second half of B: B[1..m+1] = G(k..k+m)
// Shift so that B[i] = G(i) for k <= i <= k+m.
for(reg int i=k+m;i>=k;--i) B[i] = B[i-k+1];
for(reg int i=0;i!=k;++i) B[i] = 0;
multiply(f,a,k,k,k-1,d); // first half of B via one convolution: d = low k coefficients of f*a (f[0] is still 0 here)
// B[i] = a_i - sum_{j=1}^{i} f_j a_{i-j} for i < k.
for(reg int i=0;i!=k;++i) B[i] = dec(a[i],d[i]);
// f[0] = -1, so F[0] = 1 and F[i] = -f_i: F(x) = 1 - sum_{j=1}^{k} f_j x^j.
f[0] = p-1;
for(reg int i=0;i<=k;++i) F[i] = p-f[i];
inverse(F,F,m+k);
multiply(F,B,m+k,m+k,m+k,A); // A = B * F^{-1} mod x^{m+k+1}; these become the initial values
for(reg int i=0;i<=m+k;++i) a[i] = A[i];
memset(F,0,sizeof(F));
for(reg int i=0;i<=m+1;++i) F[i] = (m+1-i)&1?p-binom(m+1,i):binom(m+1,i); // higher-order difference coefficients: F = (x-1)^{m+1}
T = m+k+1;
multiply(F,f,m+1,k,T,f); // homogenize: recurrence polynomial times the difference polynomial (degree T)
memset(G,0,sizeof(G));
// Characteristic polynomial of the order-T recurrence: coefficients reversed and negated.
for(reg int i=0;i<=T;++i) G[T-i] = p-f[i];
// F = x^n mod G
mod_power(G,T,n,F);
// Combine the remainder's coefficients with the T initial values.
for(reg int i=0;i!=T;++i) ans = (ans+(ll)F[i]*a[i])%p;
print(ans);
return 0;
}