P10375 $O(n \log^2 n)$ 优化做法
NaCly_Fish · · 题解
先来写个时间复杂度
首先我们有一个朴素的 DP 就是
有初始值
要考虑优化的话,比较直接的想法就是按 行/列 建立生成函数。如果要按行做的话,哈哈,那你就掉沟里了。
这题比较好的做法是按列来做,设
最终答案只和
其中
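这个东西显然可以分治来计算。记 $c_j = j(m-j+1)$，$Q_j(x) = 1 - mx + j(m-j)x^2$，设
$$F_{l,r}(x) = \sum_{i=l}^{r} \prod_{j=l}^{i} \frac{c_j x^2}{Q_j(x)}$$
然后就能根据下式来分治计算（其中 $mid = \lfloor (l+r)/2 \rfloor$）：
$$F_{l,r}(x) = F_{l,mid}(x) + \left( \prod_{j=l}^{mid} \frac{c_j x^2}{Q_j(x)} \right) F_{mid+1,r}(x)$$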
注意将幂级数表示为分式的形式，这样分子和分母的度数都是 $O(r-l+1)$ 的。
给个答案对 $998244353$ 取模的代码：
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
#define N 524292
#define p 998244353
#define ll long long
using namespace std;
inline int power(int a,int t){
int res = 1;
while(t){
if(t&1) res = (ll)res*a%p;
a = (ll)a*a%p;
t >>= 1;
}
return res;
}
// Bit width of the table size chosen by init(): init() increments it once for
// every doubling of its local `lim`. dft() uses it to rescale `rev` entries.
int siz;
// Lookup tables filled by init() and read by dft():
//   rev[i] - index i with its lowest `siz` bits reversed;
//   rt[i]  - butterfly factors (powers of a root of unity derived from 3);
//            for a butterfly half-width `mid`, dft() reads rt[mid | k].
int rev[N],rt[N];
/// Builds the bit-reversal table `rev` and the root table `rt` for transforms
/// of length up to `lim`, the smallest power of two strictly greater than n.
///
/// The bit count `siz` is recomputed from scratch on every call, so calling
/// init() again rebuilds consistent tables instead of accumulating bits.
/// n <= 0 is tolerated (lim == 1): no negative shift and no negative index.
/// NOTE(review): nothing checks that lim fits inside rev/rt (size N); the
/// caller is assumed to keep n small enough.
void init(int n){
    siz = 0;
    int lim = 1;
    while(lim<=n) lim <<= 1,++siz;

    // rev[i] is i with its low `siz` bits reversed; index 0 is its own reversal.
    rev[0] = 0;
    for(int i=1;i<lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));

    // The upper half [lim/2, lim) holds successive powers of w, where w is 3
    // raised to (p-1)/2^siz.
    const int half = lim>>1;
    const int w = power(3,(p-1)>>siz);
    rt[half] = 1;
    for(int i=half+1;i<lim;++i) rt[i] = static_cast<int>(1LL*rt[i-1]*w%p);

    // Each lower block is filled from the block above it by taking every
    // second entry: rt[i] = rt[2i].
    for(int i=half-1;i>0;--i) rt[i] = rt[i<<1];
}
/// In-place forward number-theoretic transform of f[0..n-1] modulo p.
///
/// Preconditions: n is a power of two not exceeding the table size prepared by
/// init(), and every f[i] is already reduced to the range [0, p).
///
/// Values are accumulated lazily in 64-bit unsigned integers: a butterfly only
/// adds, so an entry can grow by up to p per level. To keep the product
/// a * rt below 2^64 the buffer is fully reduced after every 16 levels. Without
/// that reduction a 19-level transform (length 2^19, which the array size N
/// permits) could overflow in the worst case, because 19 * p * p > 2^64.
inline void dft(int *f,int n){
    static unsigned long long a[N];
    const int shift = siz-__builtin_ctz(n);
    // Scatter the input into bit-reversed order for this transform length.
    for(int i=0;i!=n;++i) a[rev[i]>>shift] = f[i];

    int pending = 0;  // butterfly levels applied since the last full reduction
    for(int mid=1;mid!=n;mid<<=1){
        if(++pending>16){
            for(int i=0;i!=n;++i) a[i] %= p;
            pending = 1;
        }
        for(int j=0;j!=n;j+=(mid<<1))
            for(int k=0;k!=mid;++k){
                const unsigned long long x = a[j|k|mid]*rt[mid|k]%p;
                // "+p" keeps the difference non-negative without a branch.
                a[j|k|mid] = a[j|k]+p-x;
                a[j|k] += x;
            }
    }
    for(int i=0;i!=n;++i) f[i] = a[i]%p;
}
/// In-place inverse transform of f[0..n-1]: reverses f[1..n-1], applies the
/// forward transform, then multiplies every entry by the inverse of n mod p.
inline void idft(int *f,int n){
    std::reverse(f+1,f+n);
    dft(f,n);
    // For n dividing p-1, p-(p-1)/n is the modular inverse of n:
    // n*(p-(p-1)/n) = (n-1)*p + 1.
    const long long inv_n = p-(p-1)/n;
    for(int *it=f; it!=f+n; ++it) *it = static_cast<int>(*it*inv_n%p);
}
/// Returns the smallest power of two strictly greater than n.
///
/// n <= 0 yields 1. The bit-scan builtin is undefined for a zero argument, and
/// _inv() calls getlen(0) when asked to invert a degree-0 series.
/// Precondition: n < 2^30, so the result fits in an int.
inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}
/// Power-series inverse modulo p by Newton iteration.
///
/// Reads f[0..n] and writes to r[0..n] the coefficients of g with
/// f*g = 1 (mod x^(n+1)). f[0] must be invertible modulo p. r may alias f:
/// the output is only written after the last read of f.
inline void _inv(const int *f,int n,int *r){
    static int g[N],h[N],st[30];
    memset(g,0,sizeof(int)*getlen(n<<1));

    // Precision schedule n, n/2, n/4, ..., 1; the top of the stack is the
    // smallest target and is processed first.
    int top = 0;
    for(int d=n; d!=0; d>>=1) st[++top] = d;

    g[0] = power(f[0],p-2);  // degree-0 inverse via Fermat's little theorem
    int lim = 1;
    for(int lvl=top; lvl>=1; --lvl){
        const int cur = st[lvl];
        while(lim<=(cur<<1)) lim <<= 1;

        // h = f truncated to degree cur, zero-padded (one slot past lim).
        memcpy(h,f,sizeof(int)*(cur+1));
        memset(h+cur+1,0,sizeof(int)*(lim-cur));

        dft(g,lim);
        dft(h,lim);
        // Newton step g <- g*(2 - g*h), evaluated pointwise.
        for(int i=0;i<lim;++i){
            const long long factor = 2-1LL*g[i]*h[i]%p+p;
            g[i] = static_cast<int>(g[i]*factor%p);
        }
        idft(g,lim);

        // Drop every coefficient above the current precision.
        memset(g+cur+1,0,sizeof(int)*(lim-cur));
    }
    memcpy(r,g,sizeof(int)*(n+1));
}
/// Polynomial with coefficients modulo p; a[i] is the coefficient of x^i.
struct poly{
    vector<int> a;

    /// Read-only access: any index outside the stored range (negative or past
    /// the end) reads as 0.
    inline int operator [] (const int& x) const{
        return (x>=0 && static_cast<size_t>(x)<a.size()) ? a[x] : 0;
    }
    /// Mutable access: x must be a valid index, no bounds check is performed.
    inline int& operator [] (const int& x){ return a[x]; }
    /// Index of the highest stored coefficient (-1 when nothing is stored).
    inline int deg() const{ return static_cast<int>(a.size())-1; }
    /// Resizes storage so that coefficients 0..n are present.
    inline void resize(int n){ a.resize(n+1); }

    /// Power-series inverse truncated to the same number of coefficients.
    /// Requires a[0] to be invertible modulo p; an empty polynomial yields an
    /// empty result instead of being handed to _inv().
    inline poly inverse(){
        poly res;
        if(a.empty()) return res;
        const int n = deg();
        res.resize(n);
        // _inv() finishes reading its input before it writes the output, so it
        // can read this object's storage directly and write straight into res.
        _inv(a.data(),n,res.a.data());
        return res;
    }
};
/// Orders polynomials by degree, inverted: f compares "less" than g exactly
/// when f has the strictly larger degree.
inline bool operator < (const poly& f,const poly& g){
    const int lhs_deg = f.deg();
    const int rhs_deg = g.deg();
    return rhs_deg < lhs_deg;
}
/// Product of two polynomials modulo p.
///
/// If either factor has degree <= 4 the product is computed by direct
/// convolution; otherwise both factors are transformed with dft(), multiplied
/// pointwise and transformed back with idft(). The transform length is the
/// smallest power of two strictly greater than deg(f)+deg(g), and the caller
/// must have run init() for at least that length.
inline poly operator * (const poly& f,const poly& g){
    static int A[N],B[N];
    const int df = f.deg(), dg = g.deg();
    poly res;
    res.resize(df+dg);
    if(df<=4||dg<=4){
        for(int i=0;i<=df;++i)
            for(int j=0;j<=dg;++j)
                res[i+j] = (res[i+j] + 1LL*f[i]*g[j])%p;
    }else{
        const int lim = 1<<(32-__builtin_clz(df+dg));
        // Load both factors and zero-pad exactly up to the transform length.
        copy(f.a.begin(),f.a.end(),A);
        fill(A+df+1,A+lim,0);
        copy(g.a.begin(),g.a.end(),B);
        fill(B+dg+1,B+lim,0);
        dft(A,lim);
        dft(B,lim);
        for(int i=0;i!=lim;++i) A[i] = static_cast<int>(1LL*A[i]*B[i]%p);
        idft(A,lim);
        copy(A,A+df+dg+1,res.a.begin());
    }
    return res;
}
/// Coefficient-wise sum of two polynomials modulo p. The result has the larger
/// of the two degrees; coefficients missing from the shorter operand read as 0.
inline poly operator + (const poly& f,const poly& g){
    const int top = std::max(f.deg(),g.deg());
    poly sum;
    sum.resize(top);
    for(int i=top;i>=0;--i) sum.a[i] = (f[i]+g[i])%p;
    return sum;
}
// pd[u]: filled by prod(). For the index range handled by tree node u it holds
// the product, modulo p, of i*(m-i+1) over every index i in that range.
// solve() reads pd[u<<1] (the left child's product) when merging two halves.
int pd[N];
// n, m: the two integers main() reads from standard input; main() reduces m
// modulo p before any other function uses it.
int n,m;
void prod(int l,int r,int u){
if(l==r){
pd[u] = (ll)l*(m-l+1+p)%p;
return;
}
int mid = (l+r)/2;
prod(l,mid,u<<1);
prod(mid+1,r,u<<1|1);
pd[u] = (ll)pd[u<<1]*pd[u<<1|1]%p;
}
/// Divide-and-conquer evaluation, as a fraction (numerator, denominator), of
///
///   F(l,r) = sum over i in [l,r] of  prod over j in [l,i] of  c_j*x^2 / Q_j(x)
///
/// with c_j = j*(m-j+1) and Q_j(x) = 1 - m*x + j*(m-j)*x^2, all modulo p.
/// The returned denominator is the product of Q_j over [l,r].
///
/// Node numbering matches prod(); prod() must have been run over the same
/// range first because merging reads pd[u<<1], the left half's product of c_j.
///
/// An empty range (l > r) returns the empty sum 0/1 instead of recursing
/// forever.
pair<poly,poly> solve(int l,int r,int u){
    if(l>r){
        poly P,Q;
        P.resize(0), Q.resize(0);
        P[0] = 0, Q[0] = 1;
        return make_pair(P,Q);
    }
    if(l==r){
        poly P,Q;
        P.resize(2), Q.resize(2);
        P[0] = P[1] = 0, P[2] = static_cast<int>(1LL*l*(m-l+1+p)%p);
        // (p-m)%p keeps the coefficient canonical when m is 0 modulo p.
        Q[0] = 1, Q[1] = (p-m)%p, Q[2] = static_cast<int>(1LL*l*(m-l+p)%p);
        return make_pair(P,Q);
    }
    const int mid = (l+r)/2;
    pair<poly,poly> L = solve(l,mid,u<<1);
    pair<poly,poly> R = solve(mid+1,r,u<<1|1);

    // F(l,r) = P_L/Q_L + (pd_L * x^k / Q_L) * (P_R/Q_R)
    //        = (P_L*Q_R + pd_L * x^k * P_R) / (Q_L*Q_R),  k = 2*(mid-l+1).
    L.first = L.first * R.second;
    L.second = L.second * R.second;

    const int scale = pd[u<<1];
    for(size_t i=0;i<R.first.a.size();++i)
        R.first.a[i] = static_cast<int>(1LL*R.first.a[i]*scale%p);

    // Multiply by x^k: prepend k zero coefficients.
    const int k = (mid-l+1)*2;
    R.first.a.insert(R.first.a.begin(),static_cast<size_t>(k),0);

    return make_pair(L.first + R.first, L.second);
}
/// Reads n and m, then prints the coefficient of x^n of the series produced by
/// solve(1, n/2) divided out as numerator * inverse(denominator), modulo p.
int main(){
    // Stop if the input is malformed instead of computing on default values.
    if(scanf("%d%d",&n,&m)!=2) return 1;

    // For n < 2 the range [1, n/2] is empty: the sum has no terms, so the
    // coefficient is 0. Answering here also keeps prod()/solve() away from an
    // empty range and the final f[n] read inside the stored coefficients.
    if(n<2){
        printf("0");
        return 0;
    }

    m %= p;
    init(n*2);  // the final product has degree up to 2n
    const int half = n/2;
    prod(1,half,1);
    pair<poly,poly> res = solve(1,half,1);
    poly f = res.first * res.second.inverse();
    printf("%d",f[n]);
    return 0;
}