题解 P4191 【[CTSC2010]性能优化】
NaCly_Fish · · 题解
萌新刚学 FFT,然后就看见了这题,,
首先观察题目中给的伪代码,不难发现要求出序列
你要知道一个结论:两个长为
证明很简单,套一下单位根反演就行了,不是这里主要讲的内容。
现在的问题就是怎么快速做给定长为
设
这样就有
代入单位根就是
直接分治做即可。
这里写的是递归版,比较慢;稍微改改,处理出每个数在最后跑到哪个位置(即普通 FFT 中的 rev 数组),一层层合并上去,就是迭代版了。
参考代码(常数巨大):
#pragma GCC optimize (2)
#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 500003
#define ll long long
#define reg register
using namespace std;
inline void read(int &x){
x = 0;
char c = getchar();
while(c<'0'||c>'9') c = getchar();
while(c>='0'&&c<='9'){
x = (x<<3)+(x<<1)+(c^48);
c = getchar();
}
}
void print(int x){
if(x>9) print(x/10);
putchar(x%10+'0');
}
int p,r;
inline int power(int a,int t){
int res = 1;
while(t){
if(t&1) res = (ll)res*a%p;
a = (ll)a*a%p;
t >>= 1;
}
return res;
}
inline int findrt(int x){
static int fac[N];
int cnt = 0,m = x-1;
for(reg int i=2;i*i<=m;++i){
if(m%i!=0) continue;
fac[++cnt] = i;
while(m%i==0) m /= i;
}
if(m>1) fac[++cnt] = m;
for(reg int i=2;i<=x;++i){
bool flag = true;
for(reg int j=1;j<=cnt;++j){
if(power(i,(x-1)/fac[j])!=1) continue;
flag = false;
break;
}
if(flag) return i;
}
return -1;
}
int fac[233]; // it's factor not factorial
int cnt;
inline void getfac(int x){
for(reg int i=2;i*i<=x;++i){
if(x%i!=0) continue;
while(x%i==0){
fac[++cnt] = i;
x /= i;
}
}
if(x>1) fac[++cnt] = x;
}
void dft(int *f,int n,int dep){
int rt[n];
rt[0] = 1,rt[1] = power(r,(p-1)/n);
for(reg int i=2;i!=n;++i) rt[i] = (ll)rt[i-1]*rt[1]%p;
if(n<=64){
int a[n];
memset(a,0,n<<2);
for(reg int i=0;i!=n;++i)
for(reg int j=0;j!=n;++j)
a[i] = (a[i]+(ll)f[j]*rt[i*j%n])%p;
memcpy(f,a,n<<2);
return;
}
int d = fac[dep],lim = n/fac[dep];
int g[d][lim];
for(reg int i=0;i!=d;++i)
for(reg int j=0;j!=lim;++j)
g[i][j] = f[i+j*d];
for(reg int i=0;i!=d;++i) dft(g[i],lim,dep+1);
for(reg int j=0;j!=n;++j){
f[j] = 0;
for(reg int i=0;i!=d;++i)
f[j] = (f[j]+(ll)rt[i*j%n]*g[i][j%lim])%p;
}
}
inline void idft(int *f,int n){
reverse(f+1,f+n);
dft(f,n,1);
int x = power(n,p-2);
for(reg int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p;
}
int n,k;
int a[N],b[N];
int main(){
read(n),read(k);
p = n+1,r = findrt(n+1);
getfac(n);
for(reg int i=0;i!=n;++i) read(a[i]);
for(reg int i=0;i!=n;++i) read(b[i]);
dft(a,n,1),dft(b,n,1);
for(reg int i=0;i!=n;++i) a[i] = (ll)a[i]*power(b[i],k)%p;
idft(a,n);
for(reg int i=0;i!=n;++i) print(a[i]),putchar('\n');
return 0;
}