题解 P5075 【[JSOI2012]分零食】
NaCly_Fish · · 题解
怎么这题比较正常的题解都在下面,,
其实就是一道基础生成函数题。
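首先,设 $f(i)=(Oi^2+Si+U)\bmod P$ 为一个人分到 $i$ 颗糖时的欢乐程度,并设其生成函数为 $F(x)=\sum_{i=1}^{m}f(i)x^i$(没有常数项,因为分到糖的人至少分到一颗)。
考虑其组合意义,不难发现 $[x^m]F^k(x)$ 就是恰好 $k$ 个人分到糖、总共分出 $m$ 颗糖时,所有方案的欢乐程度乘积之和。
那么枚举有糖的人数,答案就直接出来了:$$\text{ans}=[x^m]\sum_{k=1}^{A}F^k(x)=[x^m]\frac{F(x)-F^{A+1}(x)}{1-F(x)}=[x^m]\frac{1-F^{A+1}(x)}{1-F(x)}$$(最后一个等号是因为 $m\ge 1$,两式只相差常数项 $1$,对 $[x^m]$ 没有影响。)
由此也可以得到一个常数优化:当 $A\ge m$ 时,$F^{A+1}(x)$ 的最低次项次数大于 $m$,对答案没有贡献,此时只需做一次多项式求逆,答案就是 $[x^m]\dfrac{1}{1-F(x)}$。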
否则就只能倍增算那个多项式的幂(因为模数为合数,如果有更优的做法请告诉我)。
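另外就是多项式的长度,由于生成函数 $F(x)=\sum_{i\ge 1}f(i)x^i$ 没有常数项,有 $F^{A+1}(x)=x^{A+1}\left(\dfrac{F(x)}{x}\right)^{A+1}$,所以只需把 $F(x)/x$ 的 $A+1$ 次幂算到 $x^{m-A-1}$ 项,再整体右移 $A+1$ 位即可。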
加上这些优化,成功跑到最优解第一:
#pragma GCC optimize ("unroll-loops")
#pragma GCC optimize ("Ofast")
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
// Capacity of the global and function-static work arrays in this file.
#define N 32774
#define ll long long
#define reg register
// Addition / subtraction modulo the runtime modulus p. The result is in [0, p)
// when both operands already are. NOTE: the arguments are pasted unparenthesized
// and evaluated more than once, so pass only simple, side-effect-free expressions.
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;
// Modulus of the final answer; read from input in main().
int p;
// rev: bit-reversal index table, rt: twiddle-factor table; both are filled by init().
int rev[N],rt[N];
// Number of doublings performed by init(), i.e. log2 of the table length it set up.
int siz;
// Prime modulus used internally by the number-theoretic transform
// (init() uses 3 as the generator when building the twiddle table).
#define md 998244353
inline int power(int a,int t){
int res = 1;
while(t){
if(t&1) res = (ll)res*a%md;
a = (ll)a*a%md;
t >>= 1;
}
return res;
}
// Builds the shared transform tables for lengths up to L, where L is the
// smallest power of two strictly greater than n:
//   rev[i]    - bit-reversed index of i over log2(L) bits,
//   rt[h + k] - k-th power of a primitive (2h)-th root of unity mod md,
//               for every power of two h < L and 0 <= k < h.
// siz is incremented here and never reset, so this is meant to be called once.
void init(int n){
    int len = 1;
    while(len<=n){
        len <<= 1;
        ++siz;
    }
    for(int i=1;i<len;++i)
        rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));

    const int half = len>>1;
    const int step = power(3,(md-1)>>siz);   // primitive len-th root of unity
    rt[half] = 1;
    // Top level: consecutive powers of the root.
    for(int i=half+1;i<len;++i) rt[i] = (long long)rt[i-1]*step%md;
    // Lower levels reuse every second entry of the level above.
    for(int i=half-1;i!=0;--i) rt[i] = rt[i<<1];
}
inline void dft(int *f,int lim){
static unsigned long long a[N];
reg int x,shift = siz-__builtin_ctz(lim);
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1)
for(reg int j=0;j!=lim;j+=(mid<<1))
for(reg int k=0;k!=mid;++k){
x = a[j|k|mid]*rt[mid|k]%md;
a[j|k|mid] = a[j|k]+md-x;
a[j|k] += x;
}
for(reg int i=0;i!=lim;++i) f[i] = a[i]%md;
}
// In-place inverse transform of f[0..lim-1] modulo md: reverses f[1..lim-1],
// runs the forward transform, then scales every entry by lim^{-1} mod md.
inline void idft(int *f,int lim){
    for(int i=1,j=lim-1;i<j;++i,--j){
        const int tmp = f[i];
        f[i] = f[j];
        f[j] = tmp;
    }
    dft(f,lim);
    // lim divides md-1, so ((md-1)/lim)*lim == -1 (mod md); negating gives lim^{-1}.
    const int inv_len = md-((md-1)>>__builtin_ctz(lim));
    for(int *it=f;it!=f+lim;++it) *it = (long long)(*it)*inv_len%md;
}
// Returns the smallest power of two strictly greater than n
// (1 -> 2, 3 -> 4, 4 -> 8). For n <= 0 the answer is 1.
// The n <= 0 case is handled explicitly because __builtin_clz(0) is undefined;
// multiply() reaches it with n == 0 when two constant polynomials are multiplied.
static inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}
void multiply(const int *A,const int *B,int n,int m,int *R,int len){
static int f[N],g[N];
memcpy(f,A,(n+1)<<2),memcpy(g,B,(m+1)<<2);
int lim = getlen(n+m);
memset(f+n+1,0,(lim-n)<<2);
memset(g+m+1,0,(lim-m)<<2);
dft(f,lim),dft(g,lim);
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*g[i]%md;
idft(f,lim);
for(reg int i=0;i<=len;++i) R[i] = f[i]%p;
}
inline void inverse(const int *f,int n,int *R){
static int g[N],h[N];
memset(g,0,getlen(n<<1)<<2);
int lim = 1,top = 0;
int s[30];
while(n){
s[++top] = n;
n >>= 1;
}
g[0] = 1;
while(top--){
n = s[top+1];
while(lim<=(n<<1)) lim <<= 1;
memcpy(h,f,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
multiply(h,g,n,n,h,n);
multiply(h,g,n,n,h,n);
for(reg int i=0;i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]);
}
memcpy(R,g,(n+1)<<2);
}
inline void power(int *f,int n,int k,int *R){
int g[N];
g[0] = 1;
while(1){
if(k&1) multiply(g,f,n,n,g,n);
k >>= 1;
if(k==0) break;
multiply(f,f,n,n,f,n);
}
memcpy(R,g,(n+1)<<2);
}
// Input parameters read by main() (together with the modulus p declared earlier).
int m,A,o,s,u;
// Polynomial work arrays used by main(), indexed by coefficient degree.
int f[N],g[N],h[N];
int main(){
scanf("%d%d%d%d%d%d",&m,&p,&A,&o,&s,&u);
init(m<<1);
for(reg int i=1;i<=m;++i) f[i] = (u+i*(s+o*i))%p;
if(A<m) memcpy(g,f+1,(m-A)<<2);
f[0] = 1;
for(reg int i=1;i<=m;++i) f[i] = f[i]==0?0:p-f[i];
inverse(f,m,f);
if(A>=m){
printf("%d",f[m]);
return 0;
}
power(g,m-A-1,A+1,h);
for(reg int i=m;i>A;--i) h[i] = h[i-A-1];
for(reg int i=A+1;i<=m;++i) h[i] = h[i]==0?0:p-h[i];
int ans = f[m];
for(reg int i=A+1;i<=m;++i)
if(h[i]!=0) ans = (ans+h[i]*f[m-i])%p;
printf("%d",ans);
return 0;
}