题解 P4191 【[CTSC2010]性能优化】

· · 题解

萌新刚学 FFT,然后就看见了这题,,

首先观察题目中给的伪代码,不难发现要求出序列 a 与「序列 b 的循环卷积 c 次幂」的循环卷积。

你要知道一个结论:两个长为 n 的序列分别 DFT 后,对应项乘起来再 IDFT,就是两个序列做长度为 n 的循环卷积的结果。

证明很简单,套一下单位根反演就行了,不是这里主要讲的内容。

现在的问题就是怎么快速做给定长为 n 的 DFT。虽然 Bluestein 可以做,但常数过大;不妨考虑题目中给出的条件:n 的质因子都不超过 10

dn 的一个质因子,对 A(x) 做 DFT 时,可以分成 d 个多项式:

A_i(x)=\sum_{j=0}^{n/d-1} a_{i+jd}x^j

这样就有

A(x)=\sum_{i=0}^{d-1}x^iA_i(x^d)

代入单位根就是

A(\omega_n^j)=\sum_{i=0}^{d-1}\omega_n^{ij}A_i(\omega_n^{jd}) A(\omega_n^j)=\sum_{i=0}^{d-1}\omega_n^{ij}A_i(\omega_{n/d}^j)

直接分治做即可。

这里写的是递归版,比较慢;稍微改改,处理出每个数在最后跑到哪个位置(即普通 FFT 中的 rev 数组),一层层合并上去,就是迭代版了。

参考代码(常数巨大):

#pragma GCC optimize (2)
#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 500003
#define ll long long
#define reg register
using namespace std;

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int p,r;

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1;
    }
    return res;
}

inline int findrt(int x){
    static int fac[N];
    int cnt = 0,m = x-1;
    for(reg int i=2;i*i<=m;++i){
        if(m%i!=0) continue;
        fac[++cnt] = i;
        while(m%i==0) m /= i;
    }
    if(m>1) fac[++cnt] = m;
    for(reg int i=2;i<=x;++i){
        bool flag = true;
        for(reg int j=1;j<=cnt;++j){
            if(power(i,(x-1)/fac[j])!=1) continue;
            flag = false;
            break;
        }
        if(flag) return i;
    }
    return -1;
}

int fac[233]; // it's factor not factorial
int cnt;

inline void getfac(int x){
    for(reg int i=2;i*i<=x;++i){
        if(x%i!=0) continue;
        while(x%i==0){
            fac[++cnt] = i;
            x /= i;
        }
    }
    if(x>1) fac[++cnt] = x;
}

void dft(int *f,int n,int dep){
    int rt[n];
    rt[0] = 1,rt[1] = power(r,(p-1)/n);
    for(reg int i=2;i!=n;++i) rt[i] = (ll)rt[i-1]*rt[1]%p;
    if(n<=64){
        int a[n];
        memset(a,0,n<<2);
        for(reg int i=0;i!=n;++i)
        for(reg int j=0;j!=n;++j)
            a[i] = (a[i]+(ll)f[j]*rt[i*j%n])%p;
        memcpy(f,a,n<<2);
        return;
    }
    int d = fac[dep],lim = n/fac[dep];
    int g[d][lim];
    for(reg int i=0;i!=d;++i)
    for(reg int j=0;j!=lim;++j)
        g[i][j] = f[i+j*d];
    for(reg int i=0;i!=d;++i) dft(g[i],lim,dep+1);    
    for(reg int j=0;j!=n;++j){
        f[j] = 0;
        for(reg int i=0;i!=d;++i)
            f[j] = (f[j]+(ll)rt[i*j%n]*g[i][j%lim])%p;
    }
}

inline void idft(int *f,int n){
    reverse(f+1,f+n);
    dft(f,n,1);
    int x = power(n,p-2);
    for(reg int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p;
}

int n,k;
int a[N],b[N];

int main(){ 
    read(n),read(k);
    p = n+1,r = findrt(n+1);
    getfac(n);
    for(reg int i=0;i!=n;++i) read(a[i]);
    for(reg int i=0;i!=n;++i) read(b[i]);
    dft(a,n,1),dft(b,n,1);
    for(reg int i=0;i!=n;++i) a[i] = (ll)a[i]*power(b[i],k)%p;
    idft(a,n);
    for(reg int i=0;i!=n;++i) print(a[i]),putchar('\n');
    return 0;   
}