题解 P7342 【『MdOI R4』Destiny】

· · 题解

教程:如何正确地使用 OEIS

首先当然是考虑每个数被算进答案多少次,设 T_{n,k} \ (k\in[0,n]) 是长为 n+1 的序列中,第 k+1 个数被算的次数。

随便打个暴力,在 OEIS 上可以找到这个数列,然后眼尖的你一眼就看到了递推式:

T_{n,k}=2(T_{n-1,k-1}+T_{n-1,k})-(2-[n=2])T_{n-2,k-1}\ (n\ge 2,k\ge1)

上面没给每行的生成函数,只好我们自己算了,,

这个 [n=2] 的条件有点麻烦,如果假定 T_{0,0} = \dfrac12 就简单很多。设 T_n(x)\left\{ T_{n,k}\right\}_{k=0}^n 的生成函数,就有:

T_n(x)=2(1+x)T_{n-1}(x)-2xT_{n-2}(x)

按二阶递推数列解出来就是

T_n(x)=\frac{(1+x+\sqrt{x^2+1})^{n+1}-(1+x-\sqrt{x^2+1})^{n+1}}{4\sqrt{1+x^2}}

根据数列 a 的构造方法,我们知道答案是

\sum_{i=0}^{n-1}b_{i \bmod k}T_{n-1,i}

而这实际上就是 T_{n-1}(x) \bmod (x^k-1) 与序列 b 的点积。要计算的话,就是维护一个 F(x)=f(x)+\sqrt{1+x^2}g(x) 形式的多项式;而最后一步除 \sqrt{1+x^2} 实际上就是提取其中 g(x) 的系数。

顺便一提,容易证明 (1+x+\sqrt{x^2+1})^n(1+x-\sqrt{x^2+1})^n\sqrt{1+x^2} 部分是互为相反数的,故可以只算一遍快速幂,常数减半。

代码奉上:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 262147
#define ll long long
#define reg register
#define p 998244353
using namespace std;

inline int read(){
    int x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
    return x;
}

inline int add(const int& x,const int& y){ return x+y>=p?x+y-p:x+y; }
inline int dec(const int& x,const int& y){ return x<y?x-y+p:x-y; }
inline int rec(const int& x){ return x>=p?x-p:x; }

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1;
    }
    return res;
}

int siz;
int rev[N],rt[N],inv[N],fac[N],ifac[N];

void init(int n){
    int lim = 1;
    while(lim<=n) lim <<= 1,++siz;
    for(reg int i=0;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    int w = power(3,(p-1)>>siz);
    fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1;
    for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    for(reg int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
    ifac[n] = power(fac[n],p-2);
    for(reg int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
    for(reg int i=2;i<=n;++i) inv[i] = (ll)fac[i-1]*ifac[i]%p;
}

inline void dft(int *f,int n){
    static unsigned long long a[N];
    reg int x,shift = siz-__builtin_ctz(n);
    for(reg int i=0;i!=n;++i) a[rev[i]>>shift] = f[i];
    for(reg int mid=1;mid!=n;mid<<=1)
    for(reg int j=0;j!=n;j+=(mid<<1))
    for(reg int k=0;k!=mid;++k){
        x = a[j|k|mid]*rt[mid|k]%p;
        a[j|k|mid] = a[j|k]+p-x;
        a[j|k] += x;
    }
    for(reg int i=0;i!=n;++i) f[i] = a[i]%p;
}

inline void idft(int *f,int n){
    reverse(f+1,f+n);
    dft(f,n);
    int x = p-((p-1)>>__builtin_ctz(n));
    for(reg int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p;
}

inline int getlen(int n){
    return 1<<(32-__builtin_clz(n));
}

struct poly{
    int a[N],b[N];
    int t;

    inline poly(int _t=0):t(_t){
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));
    }
};

inline void getv(int lim,int *v){
    int rt = power(3,(p-1)/lim),x = 1;
    for(reg int i=0;i!=lim;++i){
        v[i] = ((ll)x*x+1)%p;
        x = (ll)x*rt%p;
    }
}

inline poly multiply(poly f,poly g){
    static int A[N],B[N],C[N],D[N],R[N],S[N],v[N];
    int n = f.t,lim = getlen(f.t<<1);
    getv(lim,v); // 这里是线性计算 (1+x^2) 的 dft,以减少此后的一次 idft
    poly h = poly(n);
    memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2);
    memcpy(C,g.a,n<<2),memcpy(D,g.b,n<<2);
    memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2);
    memset(C+n,0,(lim-n+2)<<2),memset(D+n,0,(lim-n+2)<<2);
    dft(A,lim),dft(B,lim),dft(C,lim),dft(D,lim);
    for(reg int i=0;i!=lim;++i){
        R[i] = ((ll)A[i]*C[i]+(ll)B[i]*D[i]%p*v[i])%p;
        S[i] = ((ll)A[i]*D[i]+(ll)B[i]*C[i])%p;
    }
    idft(R,lim),idft(S,lim);
    for(reg int i=0;i!=n;++i){
        h.a[i] = add(R[i],R[i+n]);
        h.b[i] = add(S[i],S[i+n]);
    }
    h.a[0] = add(h.a[0],R[n<<1]);
    h.a[1] = add(h.a[1],R[n<<1|1]);
    return h;
}

inline poly square(poly f){ // 算平方的时候可以减少 dft 次数
    static int A[N],B[N],R[N],S[N],v[N];
    int n = f.t,lim = getlen(f.t<<1);
    getv(lim,v);
    poly h = poly(n);
    memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2);
    memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2);
    dft(A,lim),dft(B,lim);
    for(reg int i=0;i!=lim;++i){
        R[i] = ((ll)A[i]*A[i]+(ll)B[i]*B[i]%p*v[i])%p;
        S[i] = ((ll)A[i]*B[i]<<1)%p;
    }
    idft(R,lim),idft(S,lim);
    for(reg int i=0;i!=n;++i){
        h.a[i] = add(R[i],R[i+n]);
        h.b[i] = add(S[i],S[i+n]);
    }
    h.a[0] = add(h.a[0],R[n<<1]);
    h.a[1] = add(h.a[1],R[n<<1|1]);
    return h;
}

inline poly power(poly f,int n,int t){
    poly g = poly(n);
    g.a[0] = 1;
    while(t){
        if(t&1) g = multiply(f,g);
        t >>= 1;
        if(t==0) break;
        f = square(f);
    }
    return g;
}

int n,k,ans;
poly f;
int a[N];

int main(){
    n = read(),k = read();
    init(k<<1);
    f.t = k;
    f.a[0] = f.a[1] = f.b[0] = 1;
    f = power(f,k,n);
    for(reg int i=0;i!=k;++i) ans = (ans+(ll)f.b[i]*read())%p;
    printf("%lld",(ll)ans*power(2,p-2)%p);
    return 0;   
}