一阶微分方程的迭代求法
NaCly_Fish · · 题解
其实这个科技在 2017 年就被神仙们所熟知了,这里应 兔兔 的要求,造了道模板题,来洛谷普及一下,求轻喷。
任意一阶微分方程都能写成这种形式:$\dfrac{\mathrm{d}F(x)}{\mathrm{d}x} = G\bigl(x, F(x)\bigr)$
在洛谷模板题中,方程为 $F'(x) = A(x)\,\mathrm{e}^{F(x)-1} + B(x)$,初始条件为 $F(0) = 1$。
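假设已经求得 $F_0(x)$,满足 $F(x) \equiv F_0(x) \pmod{x^n}$。记方程右端为 $G(F)$,将它在 $F_0$ 处泰勒展开:$G(F) \equiv G(F_0) + G'(F_0)\,(F - F_0) \pmod{x^{2n}}$
(后面的项在模 $x^{2n}$ 意义下都为 $0$,其中 $G'$ 表示对 $F$ 求偏导)
我们希望把带 $F$ 的项都移到左边:$F' - G'(F_0)\,F \equiv G(F_0) - G'(F_0)\,F_0 \pmod{x^{2n}}$
如果能构造一个 $P(x)$ 满足 $P' = -G'(F_0)\,P$,那么两边同乘 $P$ 后左边恰为 $(FP)'$,即 $(FP)' \equiv P\,\bigl(G(F_0) - G'(F_0)\,F_0\bigr) \pmod{x^{2n}}$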
右边那个
注意右边的
啊,好像忘了说
对比一下
直接做就好了。常数奇大无比,完全打不过分治 FFT。
代码为了好写,没怎么卡常,仅供参考。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
// Capacity of every coefficient and scratch array (2^18 + 3 entries).
#define N 262147
// Shorthand for the 64-bit type used for intermediate products.
#define ll long long
// Expands to the `register` storage-class hint on loop variables.
// NOTE(review): `register` is no longer accepted as a storage class from
// C++17 on -- confirm the intended language standard.
#define reg register
// Modulus for all arithmetic in this file.
#define p 998244353
using namespace std;
// Modular addition: returns x+y brought back below the modulus with one
// conditional subtraction. Both arguments are expected to be already reduced.
inline int add(const int& x,const int& y){
    const int sum = x+y;
    if(sum>=p) return sum-p;
    return sum;
}
// Modular subtraction: returns x-y, adding the modulus back when the plain
// difference would be negative. Both arguments are expected to be already reduced.
inline int dec(const int& x,const int& y){
    const int diff = x-y;
    if(x<y) return diff+p;
    return diff;
}
// Reads the next run of decimal digits from stdin into x, skipping any
// leading non-digit characters. x is left at 0 when the input ends before a
// digit is seen. Signs are not interpreted and overflow is not checked.
inline void read(int &x){
    x = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from data.
    int c = getchar();
    // Stop skipping at EOF instead of spinning forever on exhausted input.
    while(c!=EOF&&(c<'0'||c>'9')) c = getchar();
    while(c>='0'&&c<='9'){
        x = x*10+(c-'0');
        c = getchar();
    }
}
// Writes x to stdout in decimal, most significant digit first.
// Intended for non-negative values; no sign handling is performed.
void print(int x){
    char low[16];
    int cnt = 0;
    // Peel off the low-order digits until a single leading digit remains.
    while(x>9){
        low[cnt++] = (char)(x%10+'0');
        x /= 10;
    }
    putchar(x%10+'0');
    // The peeled digits were collected least-significant first; emit in reverse.
    while(cnt>0) putchar(low[--cnt]);
}
// Binary exponentiation: returns a^t modulo the file-wide modulus.
// a is expected to be already reduced and t non-negative.
inline int power(int a,int t){
    int acc = 1;
    for(;t;t>>=1){
        // Fold in the current square whenever the matching exponent bit is set.
        if(t&1) acc = (ll)acc*a%p;
        a = (ll)a*a%p;
    }
    return acc;
}
// Base-2 logarithm of the transform length prepared by init().
int siz;
// Tables filled by init():
//   rev - bit-reversed index of every position over siz bits,
//   rt  - powers of the root of unity, laid out so dft() reads rt[mid|k],
//   inv - modular inverses of 1..n.
int rev[N],rt[N],inv[N];
// Precomputes the tables shared by every transform in this file.
//   n - largest index that will be needed; the transform length `lim` is the
//       smallest power of two strictly greater than n and must not exceed
//       the capacity of the global tables.
// Fills rev[0..lim-1], rt[1..lim-1] and inv[1..n], and sets siz = log2(lim).
void init(int n){
    int lim = 1;
    // Start from scratch so that a repeated call does not keep growing siz.
    siz = 0;
    while(lim<=n) lim <<= 1,++siz;
    // rev[i] is i with its siz bits reversed. Starting at 1 avoids evaluating
    // a shift by siz-1 == -1 when lim == 1.
    rev[0] = 0;
    for(int i=1;i<lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // w = 3^((p-1)/lim), used as the generator for the top layer of twiddles.
    int w = power(3,(p-1)>>siz);
    inv[1] = rt[lim>>1] = 1;
    // Top layer: rt[lim/2 + k] = w^k.
    for(int i=(lim>>1)+1;i<lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    // Every lower layer takes every second entry of the layer above it.
    // `i>0` (rather than `i!=0`) keeps the loop from running off the front of
    // the table when lim == 1.
    for(int i=(lim>>1)-1;i>0;--i) rt[i] = rt[i<<1];
    // Linear-time inverse table: inv[i] = -(p/i) * inv[p%i] (mod p).
    for(int i=2;i<=n;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
}
// In-place forward number-theoretic transform of f[0..n-1] modulo p.
// n must be a power of two no larger than the length prepared by init(),
// because the rev[] and rt[] tables and siz are read here. The results are
// written back to f reduced modulo p.
inline void dft(int *f,int n){
// 64-bit scratch buffer: the butterflies add without reducing every sum.
static unsigned long long a[N];
// rev[] holds siz-bit reversals; dropping the low `shift` bits turns them
// into reversals for an n-point transform.
reg int x,shift = siz-__builtin_ctz(n);
for(reg int i=0;i!=n;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=n;mid<<=1)
for(reg int j=0;j!=n;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
// Only the twiddled term is reduced; adding p keeps the difference non-negative.
x = a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = a[j|k]+p-x;
a[j|k] += x;
}
// NOTE(review): every layer can grow an entry by up to p, so the unreduced
// product above presumably relies on the number of layers staying small
// enough for 64 bits -- confirm before raising N.
for(reg int i=0;i!=n;++i) f[i] = a[i]%p;
}
// In-place inverse transform of f[0..n-1]: reverses f[1..n-1], runs the
// forward transform, then scales every entry by the modular inverse of n.
// n must be a power of two accepted by dft().
inline void idft(int *f,int n){
    // Mirror positions 1..n-1 around the middle; position 0 stays in place.
    for(int lo=1,hi=n-1;lo<hi;++lo,--hi){
        const int tmp = f[lo];
        f[lo] = f[hi];
        f[hi] = tmp;
    }
    dft(f,n);
    // For n dividing p-1, p-(p-1)/n is the inverse of n modulo p.
    const int scale = p-((p-1)>>__builtin_ctz(n));
    int i = 0;
    while(i!=n){
        f[i] = (ll)f[i]*scale%p;
        ++i;
    }
}
// Returns the smallest power of two strictly greater than n.
// Non-positive n yields 1; the count-leading-zeros builtin is undefined for
// a zero argument, so that case is answered without calling it.
// n is expected to be below 2^30 so that the result fits in an int.
inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}
inline void inverse(const int *f,int n,int *r){
static int g[N],h[N],st[30];
memset(g,0,getlen(n<<1)<<2);
int lim = 1,top = 0;
while(n){
st[++top] = n;
n >>= 1;
}
g[0] = 1;
while(top--){
n = st[top+1];
while(lim<=(n<<1)) lim <<= 1;
memcpy(h,f,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
dft(g,lim),dft(h,lim);
for(reg int i=0;i!=lim;++i) g[i] = g[i]*(2-(ll)g[i]*h[i]%p+p)%p;
idft(g,lim);
memset(g+n+1,0,(lim-n)<<2);
}
memcpy(r,g,(n+1)<<2);
}
inline void log(const int *f,int n,int *r){
static int g[N],h[N];
inverse(f,n,g);
for(reg int i=0;i!=n;++i) h[i] = (ll)f[i+1]*(i+1)%p;
h[n] = 0;
int lim = getlen(n<<1);
memset(g+n+1,0,(lim-n)<<2);
memset(h+n+1,0,(lim-n)<<2);
dft(g,lim),dft(h,lim);
for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
idft(g,lim);
for(reg int i=1;i<=n;++i) r[i] = (ll)g[i-1]*inv[i]%p;
r[0] = 0;
}
// Power-series exponential: writes n+1 coefficients of the result into
// r[0..n], using the Newton step g <- g*(1 - ln(g) + f).
// The iteration is seeded with the constant term 1.
// NOTE(review): that seed presumably requires f[0] == 0 -- confirm at callers.
// r may alias f: f is only read inside the loop and r is written at the end.
inline void exp(const int *f,int n,int *r){
static int g[N],h[N],st[30];
memset(g,0,getlen(n<<1)<<2);
int lim = 1,top = 0;
// Precision schedule n, n/2, ..., 1; it is consumed from the smallest up.
while(n){
st[++top] = n;
n >>= 1;
}
g[0] = 1;
while(top--){
n = st[top+1];
while(lim<=(n<<1)) lim <<= 1;
// h keeps a zero-padded copy of the current approximation.
memcpy(h,g,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
// g is overwritten in place with its own logarithm.
log(g,n,g);
// g = f - ln(g) + 1
for(reg int i=0;i<=n;++i) g[i] = dec(f[i],g[i]);
g[0] = add(g[0],1);
// Multiply by the saved approximation through the transform.
dft(g,lim),dft(h,lim);
for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
idft(g,lim);
// Drop everything above the current precision.
memset(g+n+1,0,(lim-n)<<2);
}
memcpy(r,g,(n+1)<<2);
}
inline void multiply(const int *f,const int *g,int n,int m,int *r,int len){
static int a[N],b[N];
int lim = getlen(n+m);
memcpy(a,f,(n+1)<<2),memcpy(b,g,(m+1)<<2);
memset(a+n+1,0,(lim-n)<<2),memset(b+m+1,0,(lim-m)<<2);
dft(a,lim),dft(b,lim);
for(reg int i=0;i!=lim;++i) a[i] = (ll)a[i]*b[i]%p;
idft(a,lim);
memcpy(r,a,(len+1)<<2);
}
// Solves F' = a * exp(F - 1) + b with F(0) = 1 as a power series and writes
// the coefficients F[0..n] into r.
//   a, b - coefficient arrays, indices 0..n are read
//   n    - highest coefficient index wanted
// The number of known coefficients is roughly doubled each round by solving
// the linearised equation with an integrating factor (held in h).
inline void solve(const int *a,const int *b,int n,int *r){
    static int f[N],ef[N],gf[N],dgf[N],h[N],q[N],st[30]; // P(x) in this code is h(x)
    int top = 0;
    // Clear exactly the buffer. The previous byte count was derived from
    // twice the transform length and ran past the end of f for large n.
    memset(f,0,sizeof(f));
    // Precision schedule n, n/2, ..., 1; it is consumed from the smallest up.
    while(n){
        st[++top] = n;
        n >>= 1;
    }
    f[0] = 1;
    while(top--){
        n = st[top+1];
        // ef = exp(f - 1): copy f with its constant term (which is 1) cleared.
        memcpy(ef+1,f+1,n<<2);
        ef[0] = 0;
        exp(ef,n,ef);
        // dgf = a * exp(f - 1), the derivative of the right-hand side w.r.t. F.
        multiply(ef,a,n,n,dgf,n);
        // gf = a * exp(f - 1) + b, the right-hand side at the current f.
        for(int i=0;i<=n;++i) gf[i] = add(dgf[i],b[i]);
        // h = exp(-integral of dgf), the integrating factor.
        for(int i=0;i<=n;++i) h[i] = dgf[i]==0?0:p-dgf[i];
        for(int i=n;i;--i) h[i] = (ll)h[i-1]*inv[i]%p;
        h[0] = 0;
        exp(h,n,h);
        // q = (gf - dgf*f) * h
        multiply(dgf,f,n,n,q,n);
        for(int i=0;i<=n;++i) q[i] = dec(gf[i],q[i]);
        multiply(q,h,n,n,q,n);
        // Integrate q. The constant term is f[0]*h[0] = 1.
        for(int i=n;i;--i) q[i] = (ll)q[i-1]*inv[i]%p;
        q[0] = 1;
        // Divide the integrating factor back out to obtain the refined f.
        inverse(h,n,h);
        multiply(q,h,n,n,q,n);
        memcpy(f+1,q+1,n<<2);
    }
    // n now holds the originally requested precision again (or 0).
    memcpy(r,f,(n+1)<<2);
}
// Program input and output: a and b are the coefficient arrays read by
// main(); f receives the coefficients produced by solve().
int a[N],b[N],f[N];
// Highest coefficient index; main() reads it first.
int n;
// Entry point: reads n, then n+1 coefficients of a and n+1 coefficients of b,
// solves the equation and prints the n+1 resulting coefficients, each
// followed by a single space.
int main(){
    read(n);
    // Tables are sized for products of two series with indices 0..n.
    init(n<<1);
    const int terms = n+1;
    for(int k=0;k<terms;++k) read(a[k]);
    for(int k=0;k<terms;++k) read(b[k]);
    solve(a,b,n,f);
    for(int k=0;k<terms;++k){
        print(f[k]);
        putchar(' ');
    }
    return 0;
}