P7848

· · 题解

方便起见,下文用 n 表示原题中的 2^n

我们注意到 \text{DFT} 这个东西实际上就是乘一个范德蒙德矩阵:

V= \begin{bmatrix} 1&1&1&\cdots&1\\ 1&\omega&\omega^2&\cdots&\omega^{n-1}\\ 1&\omega^2&\omega^4&\cdots&\omega^{2(n-1)}\\ 1&\omega^3&\omega^6&\cdots&\omega^{3(n-1)}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&\omega^{n-1}&\omega^{2(n-1)}&\cdots&\omega^{(n-1)(n-1)}\\ \end{bmatrix}

因此计算 \text{DFT}^k 实际也就是计算这个矩阵的 k 次方。

我们先手算一下 V^2 试试:

V^2_{ij}=\sum_{k=0}^{n-1}\omega^{ik+kj}=\left(\sum_{k=0}^{n-1}\omega^{k}\right)\omega^{i+j}

注意到除非 i+j=ni+j=0,那么这个 \sum\omega^k 就会直接变成 0 把整个消掉。

也就是说,V 平方之后大概是个这个样子:

V^2= \begin{bmatrix} n&0&0&\cdots&0\\ 0&0&0&\cdots&n\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 0&0&n&\cdots&0\\ 0&n&0&\cdots&0\\ \end{bmatrix}

我们惊讶地发现把一个东西乘上 V^2 实际上就等价于:

每一项都乘上 n,然后翻转后 n-1

也就是说,如果再乘上 V^4,后 n-1 项转了回来,那就相当于每一项乘上 n^2

那么我们先算出来 \left\lfloor\dfrac{k}{4}\right\rfloor 先乘一下,再做 k\bmod 4 次 NTT 就行了。

时间复杂度为 O(n\log n+n\log k)

#include<complex>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<cstdio>
#include<algorithm>

#define int long long
#define MAXN 2100000

using namespace std;

int x[MAXN],y[MAXN],z[MAXN];
int N,M;

int read(){
    int x=0;char c=getchar();
    for(;(c<'0'||c>'9');c=getchar());
    for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
    return x;
}

int rev[MAXN],k,L;

const int mod=998244353;
const int g=3;
const int g1=332748118;

void init(){
    rev[0]=0;
    for(int i=0;i<k;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
}

int ksm(int x,int y){
    x%=mod;int ans=1;
    for(int i=y;i;i>>=1,x=x*x%mod)if(i&1)ans=ans*x%mod;
    return ans%mod;
}

void ntt(int *a,int sz,int tag){
    for(int i=1;i<sz;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    int n=log2(sz);
    for(int i=1;i<=n;i++){
        int m=1<<i,wn;
        if(tag==1)wn=ksm(g,(mod-1)/(m));
        else wn=ksm(g1,(mod-1)/(m));
        for(int j=0;j<sz;j+=m){
            int omega=1;
            for(int l=0;l<m/2;l++){
                int x=omega*a[j+l+m/2]%mod,y=a[j+l]%mod;
                a[j+l]=(y+x)%mod;a[j+l+m/2]=(y-x+mod)%mod;
                omega=omega*wn%mod;
            }
        }
    }
}

int K,qwq;

signed main(){

    qwq=read(),K=read();
    N=(1<<qwq)-1;
    for(int i=0;i<=N;i++)x[i]=(read()+mod)%mod;

    for(k=1,L=0;k<=N;k<<=1,L++);

    int qwq=(int)(K/4)*2;K%=4;
    for(int i=0;i<=N;i++)x[i]=x[i]*(ksm((N+1),qwq)%mod)%mod;

    init();
    for(int i=1;i<=K;i++)ntt(x,k,1);

    for(int i=0;i<=N;i++)cout<<x[i]%mod<<" ";
    puts("");

    return 0;
}