题解:P12771 [POI 2018 R3] 多项式 Polynomial

· · 题解

给定一个 n 项多项式,n2 的幂次,给定 q 和模数 m,求该多项式在点 q^1, q^2, \ldots, q^n 处的取值。特别的,q^n\equiv 1\pmod m

分析

第一眼看到题目,首先想到任意模数的 Chirp Z-Transform,但是发现模数太任意了,可以很小,导致组合数没有逆元,遂放弃。

接着我们考虑能否利用 q^n\equiv 1\pmod m 这个类似单位根的性质,需要注意的是,可能是 q^d\equiv 1\pmod md \mid n,这时答案每 d 个为一循环,我们也是每 d 项算一次 n 个点值并累加,然后循环 n/d 次输出,即如下代码段:

pw[0]=1;
for(int i=1;i<=n;i++)pw[i]=1ll*pw[i-1]*q%m;
for(d=1;d<=n;d<<=1)if(pw[d]==1)break;
for(int i=0;i<n;i+=d){
    work(a+i);
    for(int j=1;j<=d;j++)add(ans[j],1ll*val[j]*kpow(pw[j],i)%m);
}
int ss=0;
for(int i=1;i<=d;i++)add(ss,1ll*ans[i]*(n/d)%m);
printf("%d\n",ss);
for(int i=0;i<n;i+=d)
    for(int j=1;j<=d;j++)printf("%d ",ans[j]);

问题可以转化为求 q^0, q^1, \ldots, q^{d-1} 处在一个 d 项多项式中的值。我们假设 qmd 次单位根,看看会出现什么情况。

我们回想一下普通 fft 在干什么事情。要想快速实现系数向点值的转化,我们将当前多项式按奇偶项分开,如 f(x)=\sum \limits_{i=0}^{15}a_ix^i,令 f_0(x)=\sum \limits_{i=0}^{7}a_{2i}x^{2i}f_1(x)=\sum \limits_{i=0}^{7}a_{2i+1}x^{2i},则 f(x)=f_0(x^2)+xf_1(x^2)。我们希望代入的两个自变量的平方相同(互为相反数),这样他们的函数值只有 x 这个位置是相反数,并且要方便递归。

单位根就有这样的优势,\omega_n^k=\omega_{n/2}^{k/2},但更重要的是,\omega_n^k+\omega_n^{k+n/2}=0

但是,q^{d/2} 有没有可能不是 -1 呢?这将导致 q^k+q^{k+d/2} \not \equiv 0\pmod m。我们可以在样例中找到这样的例子,d=271^2\equiv 1\pmod {80}71\not \equiv -1\pmod {80}

不过这样就做不了了吗? 答案是否定的。因为我们意识到,在 f(x)=f_0(x^2)+xf_1(x^2) 这一步中,最关键的是让两个自变量的平方相等,而 xf_1(x^2) 不一定要是相反数。对于这道题来说,q^{d/2} 是否是 -1 其实根本不重要,只需要 (q^{d/2})^21 就能保证 q^kq^{k+d/2} 两个数的平方相等。

于是这题就顺利地做完了。

Code

变量名和写法有所调整。

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int x=0;char c=getchar();
    while(c<'0'||c>'9')c=getchar();
    while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
    return x;
}
const int N=1<<22;
int n,m,c,mod,g,a[N],pw[N],rev[N],ans[N];
void add(int &x,int y){x+=y;if(x>=mod)x-=mod;}
int kpow(int t1,int t2){
    int res=1;
    while(t2){
        if(t2&1)res=1ll*res*t1%mod;
        t1=1ll*t1*t1%mod;t2>>=1;
    }
    return res;
}
void ntt(int *f){
    int lg=__lg(n);
    for(int i=1;i<n;i++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
        if(i<rev[i])swap(f[i],f[rev[i]]);
    }
    for(int x=1,y=2;y<=n;x<<=1,y<<=1){
        int z=pw[n/y];//求y次单位根
        for(int i=0;i<n;i+=y)
            for(int w=1,j=i;j<i+x;j++,w=1ll*w*z%mod){
                int p=f[j],q=1ll*w*f[j+x]%mod;
                f[j]=p+q<mod? p+q:p+q-mod;
                add(f[j+x]=1ll*pw[n>>1]*q%mod,p);
            }
    }
}
int main(){
    m=read();mod=read();c=read();
    pw[0]=1;
    for(int i=1;i<=m;i++)pw[i]=1ll*pw[i-1]*c%mod;
    for(n=1;n<=m;n<<=1)if(pw[n]==1)break;
    for(int i=0;i<m;i++)a[i]=read()%mod;
    for(int i=0;i<m;i+=n){
        ntt(a+i);
        for(int j=0;j<n;j++)add(ans[j],1ll*a[j+i]*kpow(pw[j],i)%mod);
    }
    int ss=0;
    for(int i=0;i<n;i++)add(ss,1ll*ans[i]*(m/n)%mod);
    printf("%d\n",ss);
    for(int i=0;i<m;i+=n){
        for(int j=1;j<n;j++)printf("%d ",ans[j]);
        printf("%d ",ans[0]);
    }
}