[MtOI2019] T6 Solution
NaCly_Fish
2019-08-24 15:49:00
upd:之前式子有一点锅,现已修复。
这其实是一个很水的套路题。。
首先容易看出来 $f(x,0)$ 是个线性递推的形式,要求的是其 $k$ 阶前缀积。
要求乘积不太好搞,可以对 $2$ 取一下对数,化乘为加。
于是问题转化为:
一个数列 $a$:
$$a_n=n\space(n\le42)$$
$$a_n=\sum\limits_{i=1}^{42}ia_{n-i}\space(n\ge 43)$$
求它 $k$ 阶前缀和的第 $n$ 项。
****
关于线性递推式的高阶前缀和有一个优美的性质。
设数列 $a$ 的递推系数为 $f$,那么在 $f$ 前面加个 $-1$ ,然后做 $k$ 阶差分得到的序列即 $a$ 的 $k$ 阶前缀和的递推式。( 当然要在后面扩展 $k$ 项,同时最后去掉 $-1$ )
在此简短证明一下,设:
$$a_n=\sum\limits_{i=1}^kf_ia_{n-i}$$
$$b_n=\sum\limits_{i=1}^na_i$$
$$b_n=b_{n-1}+a_n=b_{n-1}+\sum\limits_{i=1}^kf_ia_{n-i}$$
$$= b_{n-1}+\sum\limits_{i=1}^kf_i(b_{n-i}-b_{n-i-1})$$
后面的那个求和展开,可以得到很多形如 $(f_i-f_{i-1})b_{n-i}$ 的式子,这就是一个很明显的差分形式,接下来的证明就很容易了。
得到 $k$ 阶前缀和的递推式后,把 $a$ 的 $k$ 阶前缀和的前几项也求出来,然后直接上线性递推板子即可。
不过要注意的是我们刚才取了个 $\log$,所以以上运算都要对 $\color{red} 998244352$ 取模,这需要用到任意模数。$7$ 次 FFT 的做法常数过大,不能通过;需要使用 $4$ 次 FFT 的做法。
还有就是模数不是素数时,算组合数很麻烦,所以直接用倍增多项式快速幂计算高阶差分或前缀和即可。
时间复杂度 $\Theta(k\log^2 k+k\log k\log n)$。
std:
```cpp
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 65539
#define ll long long
#define reg register
#define p 998244352
#define pi 3.141592653589793
using namespace std;
struct complex{
double x,y;
inline complex(double x=0,double y=0):x(x),y(y){}
inline complex operator + (const complex& b) const{ return complex(x+b.x,y+b.y); }
inline complex operator - (const complex& b) const{ return complex(x-b.x,y-b.y); }
inline complex operator * (const complex& b) const{ return complex(x*b.x-y*b.y,x*b.y+y*b.x); }
inline complex operator / (const int& b) const{ return complex(x/b,y/b); }
inline complex operator ~ () const{ return complex(x,-y); }
}rt[N];
struct matrix{
int a[43][43];
int siz;
inline matrix(int _siz=0):siz(_siz){ memset(a,0,sizeof(a)); }
inline matrix operator * (const matrix& b) const{
matrix res = matrix(siz);
for(reg int i=0;i!=siz;++i)
for(reg int j=0;j!=siz;++j)
for(reg int k=0;k!=siz;++k)
res.a[i][j] = (res.a[i][j]+(ll)a[i][k]*b.a[k][j])%p;
return res;
}
};
inline matrix mat_pow(matrix a,ll t){
matrix res = matrix(a.siz);
for(reg int i=0;i!=a.siz;++i) res.a[i][i] = 1;
while(1){
if(t&1) res = res*a;
t >>= 1;
if(t==0) break;
a = a*a;
}
return res;
}
int rev[N];
int siz;
inline int add(int a,int b){ return a+b>=p?a+b-p:a+b; }
inline int dec(int a,int b){ return a<b?a-b+p:a-b; }
inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); }
inline void init(int n){
int lim = 1;
while(lim<=n) lim <<= 1,++siz;
for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
rt[lim>>1] = complex(1,0);
for(reg int i=1;i!=(lim>>1);++i) rt[i+(lim>>1)] = complex(cos(pi*2/lim*i),sin(pi*2/lim*i));
for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
}
inline void dft(complex *f,int lim){
static complex a[N];
int shift = siz-__builtin_ctz(lim);
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1)
for(reg int j=0;j!=lim;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
complex x = a[j|k|mid]*rt[mid|k];
a[j|k|mid] = a[j|k]-x;
a[j|k] = a[j|k]+x;
}
for(reg int i=0;i!=lim;++i) f[i] = a[i];
}
inline void idft(complex *f,int lim){
reverse(f+1,f+lim);
dft(f,lim);
for(reg int i=0;i!=lim;++i) f[i] = f[i]/lim;
}
inline void multiply(const int *A,const int *B,int n,int m,int *R,int len,bool flag){
static complex f[N],g[N],h[N],q[N];
complex t,f0,f1,g0,g1;
ll x,y,z;
int lim = getlen(n+m);
for(reg int i=0;i!=lim;++i){
f[i] = complex(A[i]>>15,A[i]&32767);
g[i] = complex(B[i]>>15,B[i]&32767);
}
dft(f,lim);
if(flag) for(reg int i=0;i!=lim;++i) g[i] = f[i];
else dft(g,lim);
for(reg int i=0;i!=lim;++i){
t = ~f[i?lim-i:0];
f0 = (f[i]-t)*complex(0,-0.5),f1 = (f[i]+t)*0.5;
t = ~g[i?lim-i:0];
g0 = (g[i]-t)*complex(0,-0.5),g1 = (g[i]+t)*0.5;
h[i] = f1*g1;
q[i] = f1*g0 + f0*g1 + f0*g0*complex(0,1);
}
idft(h,lim),idft(q,lim);
for(reg int i=0;i<=len;++i){
x = (ll)(h[i].x+0.5)%p<<30;
y = (ll)(q[i].x+0.5)<<15;
z = q[i].y+0.5;
R[i] = (x+y+z)%p;
}
memset(R+len+1,0,(lim-len)<<2);
}
inline void inverse(const int *f,int n,int *R){
static int g[N],h[N],st[30];
memset(g,0,getlen(n<<1)<<2);
int top = 0,lim =1 ;
while(n){
st[++top] = n;
n >>= 1;
}
g[0] = 1;
while(top--){
n = st[top+1];
while(lim<=(n<<1)) lim <<= 1;
memcpy(h,f,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
multiply(h,g,n,n>>1,h,n,false);
multiply(h,g,n,n>>1,h,n,false);
for(reg int i=(n>>1);i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]);
}
memcpy(R,g,(n+1)<<2);
}
inline void divide(const int *f,const int *ig,int n,int m,int *R){
static int A[N],B[N];
memcpy(A,f,(n+1)<<2),memcpy(B,ig,(m+1)<<2);
reverse(A,A+n+1);
int tt = n-m,lim = getlen((n-m)<<1);
memset(A+tt+1,0,(lim-tt)<<2);
for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0;
multiply(A,B,tt,tt,R,n-m,false);
reverse(R,R+tt+1);
}
inline void mod(const int *f,const int *g,const int *ig,int n,int m,int *R){
if(n<m) return;
static int A[N],B[N];
memcpy(B,f,(n+1)<<2);
int lim = getlen(n);
divide(f,ig,n,m,R);
memcpy(A,g,(m+1)<<2);
memset(A+m+1,0,(lim-m+2)<<2);
memset(R+n-m+1,0,(lim-n+m+1)<<2);
multiply(A,R,m,n-m,R,m-1,false);
for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]);
}
void mod_power(const int *G,int k,ll t,int *R){
int f[N],g[N],ig[N];
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
memset(ig,0,sizeof(ig));
memcpy(ig,G,(k+1)<<2);
reverse(ig,ig+k+1);
inverse(ig,k,ig);
int n = 1,m = 0;
f[1] = g[0] = 1;
while(1){
if(t&1){
multiply(f,g,n,m,g,n+m,false);
mod(g,G,ig,n+m,k,g);
m = min(n+m,k-1);
}
t >>= 1;
if(t==0) break;
multiply(f,f,n,n,f,n<<1,true);
mod(f,G,ig,n<<1,k,f);
n = min(n<<1,k-1);
}
memcpy(R,g,k<<2);
}
void poly_pow(const int *f,int n,int k,int *R){
static int g[N],h[N];
memset(g,0,sizeof(g)),memset(h,0,sizeof(h));
memcpy(g,f,(n+1)<<2);
h[0] = 1;
int m = 0;
while(1){
if(k&1){
multiply(h,g,n,m,h,n+m,false);
m += n;
}
k >>= 1;
if(k==0) break;
multiply(g,g,n,n,g,n<<1,true);
n <<= 1;
}
memcpy(R,h,(m+1)<<2);
}
inline int poww(int a,int t,int m){
int res = 1;
while(t){
if(t&1) res = (ll)res*a%m;
a = (ll)a*a%m;
t >>= 1;
}
return res;
}
int k,lim;
int a[N],f[N],c[N],F[N],G[N];
ll n;
int special(){
if(k==1){
for(int i=1;i<=42;++i) a[42] += a[42-i]*i;
for(int i=1;i<=42;++i) a[i] = add(a[i],a[i-1]);
for(int i=43;i;--i) f[i] = dec(f[i],f[i-1]);
}
int siz = 42+k;
matrix A = matrix(siz);
for(reg int i=0;i!=siz;++i) A.a[i][0] = f[i+1];
for(reg int i=1;i!=siz;++i) A.a[i-1][i] = 1;
A = mat_pow(A,n-1);
int res = 0;
for(reg int i=0;i!=siz;++i) res = (res+(ll)a[siz-1-i]*A.a[i][siz-1])%p;
return poww(2,res,p+1);
}
int main(){
int ans = 0;
scanf("%lld%d",&n,&k);
if(n==1){
putchar('2');
return 0;
}
f[0] = p-1;
for(reg int i=1;i<=42;++i) a[i-1] = f[i] = i;
if(k<=1){
printf("%d",special());
return 0;
}
init((k+42)<<1);
c[0] = 1,c[1] = p-1;
poly_pow(c,1,k,c);
lim = k+42;
for(reg int i=0;i<=42;++i)
for(reg int j=0;j<=k;++j)
G[i+j] = (G[i+j]+(ll)f[i]*c[j])%p;
inverse(c,lim-1,c);
for(reg int i=42;i!=lim;++i)
for(reg int j=1;j<=42;++j)
a[i] = (a[i]+(ll)a[i-j]*j)%p;
multiply(a,c,lim-1,lim-1,a,lim-1,false);
reverse(G,G+lim+1);
for(reg int i=0;i<=lim;++i) G[i] = G[i]?p-G[i]:0;
mod_power(G,lim,n-1,F);
for(reg int i=0;i!=lim;++i) ans = (ans+(ll)F[i]*a[i])%p;
printf("%d",poww(2,ans,p+1));
return 0;
}
```
ps:其实可以做到 $\Theta(k\log k + \log n)$,需要带个几百倍的常数,这里就不写了(逃