【题解】多项式反三角函数

· · 题解

\text{updated:}

之前的代码数组开的不够大,导致某些抄代码的同学RE了,在此表示非常抱歉!现已修复。

要做这道题,得先知道 多项式开根 怎么做。

给定一个多项式 F(x),求一个 G(x),满足

G(x)\equiv\sin^{-1}F(x)\space (\text{mod }x^n)

考虑对等式两边求导,得到

G'(x)\equiv \frac{F'(x)}{\sqrt{1-F(x)^2}}\space ({\text{mod }x^n})

然后再积分回来

G(x)\equiv \int \frac{F'(x)}{\sqrt{1-F(x)^2}}\text dx\space ({\text{mod }x^n})

这里用到的思想,和求多项式 \ln 的地方很类似。

\tan^{-1} 的方式,也差不多

H(x)\equiv\tan^{-1}F(x)\space (\text{mod }x^n)

H'(x)\equiv\frac{F'(x)}{1+F(x)^2}\space (\text{mod }x^n)

H(x)\equiv\int\frac{F'(x)}{1+F(x)^2} \text dx\space (\text{mod }x^n)

然后直接计算即可。
Code:

#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
// Element count of the coefficient and table arrays in this file (2^18 + 3).
#define N 262147
// Shorthand for the 64-bit type used to hold intermediate products.
#define ll long long
// Expands to the `register` storage-class hint placed on loop counters.
#define reg register
// Modulus of all arithmetic in this file (998244353 = 119 * 2^23 + 1).
#define p 998244353
using namespace std;

// Modular addition of x and y. Only one conditional subtraction is performed,
// so the result lies in [0, p) when both inputs already do.
inline int add(const int& x,const int& y){
    int sum = x+y;
    if(sum>=p) sum -= p;
    return sum;
}
// Modular subtraction x - y: p is added back once when the plain difference
// would be negative, so the result lies in [0, p) for inputs in [0, p).
inline int dec(const int& x,const int& y){
    int diff = x-y;
    if(diff<0) diff += p;
    return diff;
}

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1;
    }
    return res;
}

// Returns the smallest power of two strictly greater than n.
// __builtin_clz has an undefined result for a zero argument (and a negative n
// would ask for a 32-bit shift), so n <= 0 is answered directly with 1.
// n is expected to stay below 2^30 so that the result fits in an int.
static inline int getlen(int n){
    if(n<=0) return 1;
    return 1<<(32-__builtin_clz(n));
}

// Tables filled by init():
//   rt  - powers of roots of unity; NTT() reads rt[mid|k] as the factor for
//         butterfly k of the stage with half-length mid
//   rev - bit-reversed index of every position, taken over siz bits
//   inv - modular inverses, inv[i] = i^{-1} mod p
int rt[N],rev[N],inv[N];
// Number of doublings init() performed, i.e. log2 of the table length.
int siz;

// Builds the lookup tables shared by every transform in this file.
//   n - the tables cover transforms up to the smallest power of two strictly
//       greater than n, and inv[] is filled for indices 1..n.
// Nothing here checks capacity: the caller has to keep that power of two, and
// n itself, within the N-element arrays.
void init(int n){
    int w,lim = 1;
    // siz is only ever incremented below, so it is reset first; otherwise a
    // second call would stack its bits on top of the previous value and build
    // a wrong rev table and root.
    siz = 0;
    while(lim<=n) lim <<= 1,++siz;
    // Bit reversal over siz bits: reuse the entry of i>>1 shifted down by one
    // and place the dropped low bit of i at the top.
    for(int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    // w has order lim modulo p; 3 is used as the generator.
    w = power(3,(p-1)>>siz);
    inv[1] = rt[lim>>1] = 1;
    // Upper half of rt: rt[lim/2 + k] = w^k.
    for(int i=lim>>1|1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
    // Each lower level takes every second entry of the level above it.
    for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    // Linear-time inverses: inv[i] = -(p / i) * inv[p mod i]  (mod p).
    for(int i=2;i<=n;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
}

inline void NTT(int *f,int type,int lim){
    if(type==-1) reverse(f+1,f+lim);
    static unsigned long long a[N];
    reg int x,shift = siz-__builtin_ctz(lim);
    for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
    for(reg int mid=1;mid!=lim;mid<<=1)
    for(reg int j=0;j!=lim;j+=(mid<<1))
    for(reg int k=0;k!=mid;++k){
        x = a[j|k|mid]*rt[mid|k]%p;
        a[j|k|mid] = a[j|k]+p-x;
        a[j|k] += x;
    }
    for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
    if(type==1) return;
    x = p-((p-1)>>(siz-shift));
    for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}

inline void inverse(const int *f,int n,int *R){
    static int g[N],h[N],s[30];
    memset(g,0,getlen(n<<1)<<2);
    int lim = 1,top = 0;
    while(n){
        s[++top] = n;
        n >>= 1;
    }
    g[0] = power(f[0],p-2);
    while(top--){
        n = s[top+1];
        while(lim<=(n<<1)) lim <<= 1;
        memcpy(h,f,(n+1)<<2);
        memset(h+n+1,0,(lim-n)<<2);
        NTT(g,1,lim),NTT(h,1,lim);
        for(reg int i=0;i!=lim;++i) g[i] = g[i]*(2-(ll)g[i]*h[i]%p+p)%p;
        NTT(g,-1,lim);
        memset(g+n+1,0,(lim-n)<<2);
    }
    memcpy(R,g,(n+1)<<2);
}

// Square root of a power series: stores g with g*g == f (mod x^{n+1}) and
// g[0] = 1 into R[0..n].
//   f - input coefficients f[0..n]; the iteration starts from g[0] = 1, so
//       f[0] is expected to be 1
//   n - highest coefficient index to compute
//   R - receives n+1 coefficients; it is written only after the last read of f
// Newton iteration g <- (g*g + f) / (2g), run at the precisions
// n, n/2, n/4, ..., 1 taken from the smallest upwards.
inline void sqrt(const int *f,int n,int *R){
    static int g[N],h[N];
    memset(g,0,getlen(n<<1)<<2);
    int lim = 1,top = 0;
    int s[30];
    while(n){
        s[++top] = n;
        n >>= 1;
    }
    g[0] = 1;
    while(top--){
        n = s[top+1];
        while(lim<=(n<<1)) lim <<= 1;
        // h = 1 / (2g) mod x^{n+1}
        memcpy(h,g,(n+1)<<2);
        for(int i=0;i<=n;++i) h[i] = add(h[i],h[i]);
        inverse(h,n,h);
        // inverse() fills h[0..n] only. Everything above still holds the output
        // of an earlier transform (from the previous round or a previous call)
        // and must be cleared, otherwise it can wrap into the low coefficients
        // of the cyclic product below.
        memset(h+n+1,0,(lim-n)<<2);
        // g = g*g + f, truncated to degree n
        NTT(g,1,lim);
        for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p;
        NTT(g,-1,lim);
        for(int i=0;i<=n;++i) g[i] = add(g[i],f[i]);
        memset(g+n+1,0,(lim-n)<<2);
        // g = (g*g + f) * h, truncated to degree n
        NTT(g,1,lim),NTT(h,1,lim);
        for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
        NTT(g,-1,lim);
        memset(g+n+1,0,(lim-n)<<2);
    }
    memcpy(R,g,(n+1)<<2);
}

// Arcsine of a power series: R = integral of f' / sqrt(1 - f^2), with R[0] = 0.
//   f - input coefficients f[0..n]
//   n - highest coefficient index to compute
//   R - receives n+1 coefficients; it is written only after the last read of f
inline void asin(const int *f,int n,int *R){
    static int g[N],h[N];
    const int lim = getlen(n<<1);
    const int pad_bytes = (lim-n)<<2;

    // g = f^2 mod x^{n+1}
    memcpy(g,f,(n+1)<<2);
    memset(g+n+1,0,pad_bytes);
    NTT(g,1,lim);
    for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p;
    NTT(g,-1,lim);
    memset(g+n+1,0,pad_bytes);

    // g = 1 - f^2
    for(int i=0;i<=n;++i){
        if(g[i]) g[i] = p-g[i];
    }
    g[0] += 1;

    // g = 1 / sqrt(1 - f^2)
    sqrt(g,n,g);
    inverse(g,n,g);

    // h = f', zero-padded up to index lim
    for(int j=1;j!=n+1;++j) h[j-1] = (ll)j*f[j]%p;
    memset(h+n,0,(lim-n+1)<<2);

    // g = f' / sqrt(1 - f^2)
    NTT(g,1,lim);
    NTT(h,1,lim);
    for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
    NTT(g,-1,lim);

    // Integrate term by term; the constant of integration is 0.
    R[0] = 0;
    for(int j=1;j<=n;++j) R[j] = (ll)g[j-1]*inv[j]%p;
}

// Arctangent of a power series: R = integral of f' / (1 + f^2), with R[0] = 0.
//   f - input coefficients f[0..n]
//   n - highest coefficient index to compute
//   R - receives n+1 coefficients; it is written only after the last read of f
inline void atan(const int *f,int n,int *R){
    static int g[N],h[N];
    const int lim = getlen(n<<1);
    const int pad_bytes = (lim-n)<<2;

    // g = f^2 mod x^{n+1}
    memcpy(g,f,(n+1)<<2);
    memset(g+n+1,0,pad_bytes);
    NTT(g,1,lim);
    for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p;
    NTT(g,-1,lim);
    memset(g+n+1,0,pad_bytes);

    // g = 1 / (1 + f^2)
    g[0] += 1;
    inverse(g,n,g);

    // h = f', zero-padded up to index lim
    for(int j=1;j!=n+1;++j) h[j-1] = (ll)j*f[j]%p;
    memset(h+n,0,(lim-n+1)<<2);

    // g = f' / (1 + f^2)
    NTT(g,1,lim);
    NTT(h,1,lim);
    for(int i=0;i!=lim;++i) g[i] = (ll)g[i]*h[i]%p;
    NTT(g,-1,lim);

    // Integrate term by term; the constant of integration is 0.
    R[0] = 0;
    for(int j=1;j<=n;++j) R[j] = (ll)g[j-1]*inv[j]%p;
}

// n  - number of coefficients read from the input
// tp - selects the function: 0 runs asin(), any other value runs atan()
int n,tp;
// Coefficient buffer: holds the input series and is overwritten in place
// with the answer.
int F[N];

int main(){
    scanf("%d%d",&n,&tp);
    init(n<<2);
    for(reg int i=0;i!=n;++i) scanf("%d",&F[i]);
    if(tp==0) asin(F,n-1,F);
    else atan(F,n-1,F);
    for(reg int i=0;i!=n;++i) printf("%d ",F[i]);
    return 0;   
}