题解 P4238 【【模板】多项式求逆】

2018-02-25 14:58:50


递归求解。

求模 $x^n$ 的逆元时,假设先求出了模 $x^{\lceil\frac{ n}{2}\rceil}$ 的逆元。

设模 $x^n$ 逆元为 $B$ ,模 $x^{\lceil\frac{n}{2}\rceil}$ 逆元为 $B'$ ,则:

$$A*B'\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$

$$A*B\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$

$$\therefore B'-B\equiv 0\pmod {x^{\lceil\frac{z}{2}\rceil}}$$

两边平方,则

$$(B'-B)^2\equiv 0\pmod {x^z}$$

拆项,

$$B'^2-2BB'+B^2\equiv 0\pmod {x^z}$$

左右同乘$A$,

$$AB'^2-2B'+B\equiv 0\pmod {x^z}$$

移项,得

$$B\equiv 2B'-AB'^2\pmod {x^z}$$

然后就得到了关于 $B$ 的递推式。因为从上向下递归很麻烦,所以从下向上递推。从 $x^1$ 开始推至 $x^{2^z}(2^z>=n)$ 即可。初值是 $B=\{A(0)^{-1}\}$ 。

利用 NTT 可以将多项式乘法优化至$n\log n$,利用主定理计算得总时间复杂度为$O(n\log n)$。

代码:

#include<bits/stdc++.h>
#define For(i,a,b) for(i=(a);i<=(b);++i)
#define Forward(i,a,b) for(i=(a);i>=(b);--i)
#define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i)
#define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i)
using namespace std;
template<typename T>inline void read(T &x){
    T s=0,f=1;char k=getchar();
    while(!isdigit(k)&&k^'-')k=getchar();
    if(!isdigit(k)){f=-1;k=getchar();}
    while(isdigit(k)){s=s*10+(k^48);k=getchar();}
    x=s*f;
}
void file(void){
    #ifndef ONLINE_JUDGE
    freopen("water.in","r",stdin);
    freopen("water.out","w",stdout);
    #endif
}
const int MAXN=1<<20;
static int n,m,rev[MAXN],a[MAXN],b[2][MAXN];
inline void init()
{
    read(n);--n;
    Rep(i,0,n)read(a[i]);
    m=n<<1;
    for(n=2;n<=m;n<<=1);
}
inline void calrev(int n,int len)
{
    Rep(i,1,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<len);
}
const int mod=998244353,gen=3;
inline int modu(long long x)
{
    if(x<mod)return x;
    return x-x/mod*mod;
}
inline int power(int x,int y)
{
    static int sum;
    for(sum=1;y;y>>=1,x=modu(1ll*x*x))
        if(y&1)sum=modu(1ll*sum*x);
    return sum;
}
static int P[MAXN],iv[MAXN];
inline void predone(int n)
{
    static int i,j;
    for(i=1,j=2;j<=n<<1;++i,j<<=1)
    {
        P[i]=power(gen,(mod-1)/j);
        iv[i]=power(P[i],mod-2);
    }
}
inline int modulo(int x,int y){x+=y;if(x>=mod)x-=mod;return x;}
inline void NTT(int x[],int type)
{
    Rep(i,1,n-1)if(i<rev[i])swap(x[i],x[rev[i]]);
    static int i,j,k,kk,t,w,wn,tk;
    for(i=2,tk=1;i<=n;i<<=1,++tk)
    {
        kk=i>>1;
        if(type==1)wn=P[tk];
        else wn=iv[tk];
        for(j=0;j<n;j+=i)
        {
            w=1;
            for(k=0;k<kk;++k,w=modu(1ll*w*wn))
            {
                t=modu(1ll*w*x[j+k+kk]);
                x[j+k+kk]=module(x[j+k],mod-t);
                x[j+k]=module(x[j+k],t);
            }
        }
    }
    if(type==-1)
    {
        int inv=power(n,mod-2);
        Rep(i,0,n)x[i]=modu(1ll*x[i]*inv);
    }
}
static int X[MAXN],Y[MAXN];
inline void mul(int x[],int y[])
{
    memset(X,0,sizeof X);
    memset(Y,0,sizeof Y);
    Rep(i,0,n>>1)X[i]=x[i],Y[i]=y[i];
    NTT(X,1);NTT(Y,1);
    Rep(i,0,n)X[i]=modu(1ll*X[i]*Y[i]);
    NTT(X,-1);
    Rep(i,0,n)x[i]=X[i];
}
static int c[MAXN];
inline void test(int z)
{
    memset(c,0,sizeof c);
    Rep(i,0,n)c[i]=a[i];
    mul(c,b[z]);Rep(i,0,m>>1)cout<<c[i]<<' ';puts("");
}
inline void solve()
{
    static int t=0,bas=1,len=1;
    b[0][0]=power(a[0],mod-2);
    n=4;
    calrev(n,len);
    while(bas<m)
    {
        t^=1;
        memset(b[t],0,sizeof b[t]);
        Rep(i,0,bas)b[t][i]=module(b[t^1][i]<<1,0);
        mul(b[t^1],b[t^1]);
        mul(b[t^1],a);
        Rep(i,0,bas)b[t][i]=module(b[t][i],mod-b[t^1][i]);
        bas<<=1;n<<=1;++len;
        if(bas<m)calrev(n,len);
    }
    Rep(i,0,m>>1)printf("%d ",b[t][i]);
    puts("");
}
int main(void){
    file();
    init();
    predone(n);
    solve();
    //cerr<<1.0*clock()/CLOCKS_PER_SEC<<endl;
    return 0;
}