P4723 【模板】常系数齐次线性递推 题解

· · 题解

这是一篇为数不多用 LSB-first 算法的题解,个人认为是全谷讲这个最详细的。

算法优势:好写,只用到了 \text{NTT} 这一个板子。而且速度比其他算法快(一般情况下)。

前置知识-- Bostan-mori 算法

我们有一个生成函数 H(x)=\dfrac{F(x)}{G(x)} 。其中 \deg(F),\deg(G)\le k\le32000。要求其第 n\le10^9 项。即求 [x^n]\dfrac{F(x)}{G(x)}

推柿子:[x^n]\dfrac{F(x)}{G(x)}=[x^n]\dfrac{F(x)G(-x)}{G(x)G(-x)}=[x^n]\dfrac{F_0(x^2)+xF_1(x^2)}{G_0(x^2)}=\begin{cases}[x^{n/2}]\dfrac{F_0(x)}{G_0(x)}(2\mid n)\\\\ [x^{n/2}]\dfrac{F_1(x)}{G_0(x)}(2\nmid n)\end{cases}

[x^0]\dfrac{F(x)}{G(x)}=\dfrac{[x^0]F(x)}{[x^0]G(x)}。于是就可以求啦!

有一个细节:就是 G(-x) 要和两个多项式分别做卷积,于是要开两个数组(下面的 c,d 数组)分别和两个做卷积,不能只用一个数组分别和两个做卷积。

代码:

inline int genf(int *a,int *b,int n,int m,int k)
{
    static int c[N],d[N];
    for(;k;k>>=1)
    {
        memset(c,0,sizeof(c));for(int i=0;i<=m;i++) c[i]=b[i];
        for(int i=1;i<=m;i+=2) c[i]=(mod-c[i])%mod;
        memset(d,0,sizeof(d));for(int i=0;i<=m;i++) d[i]=c[i];
        NTT(a,c,n,m);NTT(b,d,m,m);int j=0;
        for(int i=(k&1);i<=n+m;i+=2,j++) a[i/2]=a[i];n=j-1;j=0;
        for(int i=0;i<=2*m;i+=2,j++) b[i/2]=b[i];m=j-1;
        for(int i=n+1;i<=N-5;i++) a[i]=0;
        for(int i=m+1;i<=N-5;i++) b[i]=0;
    }
    return 1ll*a[0]*ksm(b[0],mod-2)%mod;
}

LSB-first 算法

其实内核不难发现。只是名字高端一点。

已知:a_{0,1,...,k-1},f_{1,2,...,k-1},a_n=\sum\limits_{i=1}^{k}f_i \times a_{n-i}。要求 a 的生成函数。

可以得出 a=a\times\sum\limits_{i=1}^k f_ix^i +t(x)(\deg(t)\le k-1)。其中 a 是原来的数列,t(x)=\sum\limits_{i=0}^{k-1}a_ix^i-\sum\limits_{i=0}^{k-2}a_i\sum\limits_{j=1}^{k-1-i}f_jx^{i+j}。是一个幂次小于 k-1 的一个多项式(很难求,于是我们不需要知道它是多少)。

于是 a=\dfrac{t(x)}{T(x)}(T(x)=1-\sum\limits_{i=1}^k f_ix^i)

由于我们知道 a 的前 k-1 项 , T(x) 又有足足 k 次,于是 t(x) 的前 k-1 项可以通过 A\times T 做一个卷积(其中 Aa 的前 k 项构成的多项式)模 x^k 唯一确定。

最后套用 Bostan-mori 算法计算这个生成函数的第 n 项即可,复杂度 O(k\log k\log n)

代码:

//落谷 P4723
//https://www.luogu.com.cn/problem/P4723
#include<bits/stdc++.h>
#define LL long long
#define fr(x) freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);
using namespace std;
const int mod=998244353,N=4e5+5;
int n,k,a[N],b[N],f[N],w[N],mmax;
inline int rd()
{
    int x=0,zf=1;
    char ch=getchar();
    while(ch<'0'||ch>'9') (ch=='-')and(zf=-1),ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x*zf;
}
inline void wr(int x)
{
    if(x==0) return putchar('0'),putchar(' '),void();
    int num[35],len=0;
    while(x) num[++len]=x%10,x/=10;
    for(int i=len;i>=1;i--) putchar(num[i]+'0');
    putchar(' ');
}
inline int bger(int x){return x|=x>>1,x|=x>>2,x|=x>>4,x|=x>>8,x|=x>>16,x+1;}
inline int md(int x){return x>=mod?x-mod:x;}
inline int ksm(int x,int p){int s=1;for(;p;(p&1)&&(s=1ll*s*x%mod),x=1ll*x*x%mod,p>>=1);return s;}
inline void init(int mmax)
{
    for(int i=1,j,k;i<mmax;i<<=1)
        for(w[j=i]=1,k=ksm(3,(mod-1)/(i<<1)),j++;j<(i<<1);j++)
            w[j]=1ll*w[j-1]*k%mod;
}
inline void DNT(int *a,int mmax)
{
    for(int i,j,k=mmax>>1,L,*W,*x,*y,z;k;k>>=1)
        for(L=k<<1,i=0;i<mmax;i+=L)
            for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
                *y=1ll*(*x+mod-(z=*y))* *W%mod,*x=md(*x+z);
}
inline void IDNT(int *a,int mmax)
{
    for(int i,j,k=1,L,*W,*x,*y,z;k<mmax;k<<=1)
        for(L=k<<1,i=0;i<mmax;i+=L)
            for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
                z=1ll* *W* *y%mod,*y=md(*x+mod-z),*x=md(*x+z);
    reverse(a+1,a+mmax);
    for(int inv=ksm(mmax,mod-2),i=0;i<mmax;i++) a[i]=1ll*a[i]*inv%mod;
}
inline void NTT(int *a,int *b,int n,int m)
{
    mmax=bger(n+m);init(mmax);
    DNT(a,mmax);DNT(b,mmax);
    for(int i=0;i<mmax;i++) a[i]=1ll*a[i]*b[i]%mod;
    IDNT(a,mmax);
}
inline int genf(int *a,int *b,int n,int m,int k)
{
    static int c[N],d[N];
    for(;k;k>>=1)
    {
        memset(c,0,sizeof(c));for(int i=0;i<=m;i++) c[i]=b[i];
        for(int i=1;i<=m;i+=2) c[i]=md(mod-c[i]);
        memset(d,0,sizeof(d));for(int i=0;i<=m;i++) d[i]=c[i];
        NTT(a,c,n,m);NTT(b,d,m,m);int j=0;
        for(int i=(k&1);i<=n+m;i+=2,j++) a[i/2]=a[i];n=j-1;j=0;
        for(int i=0;i<=2*m;i+=2,j++) b[i/2]=b[i];m=j-1;
        for(int i=n+1;i<=N-5;i++) a[i]=0;
        for(int i=m+1;i<=N-5;i++) b[i]=0;
    }
    return 1ll*a[0]*ksm(b[0],mod-2)%mod;
}
inline int rec(int *f,int *a,int n,int k)
{
    for(int i=1;i<=k;i++) f[i]=md(mod-f[i]);f[0]=1;
    static int c[N];memset(c,0,sizeof(c));
    for(int i=0;i<=k;i++) c[i]=f[i];
    NTT(c,a,k,k);for(int i=k;i<=N-5;i++) c[i]=0;
    return genf(c,f,k,k,n);
}
int main()
{
    n=rd();k=rd();
    for(int i=1;i<=k;i++) f[i]=md(rd()%mod+mod);
    for(int i=0;i<k;i++) a[i]=md(rd()%mod+mod);
    wr(rec(f,a,n,k));
    return 0;
}