题解 P4238 【【模板】多项式求逆】

Great_Influence

2018-02-25 14:58:50

Solution

递归求解。 求模 $x^n$ 的逆元时,假设先求出了模 $x^{\lceil\frac{ n}{2}\rceil}$ 的逆元。 设模 $x^n$ 逆元为 $B$ ,模 $x^{\lceil\frac{n}{2}\rceil}$ 逆元为 $B'$ ,则: $$A*B'\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$ 且 $$A*B\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$ $$\therefore B'-B\equiv 0\pmod {x^{\lceil\frac{z}{2}\rceil}}$$ 两边平方,则 $$(B'-B)^2\equiv 0\pmod {x^z}$$ 拆项, $$B'^2-2BB'+B^2\equiv 0\pmod {x^z}$$ 左右同乘$A$, $$AB'^2-2B'+B\equiv 0\pmod {x^z}$$ 移项,得 $$B\equiv 2B'-AB'^2\pmod {x^z}$$ 然后就得到了关于 $B$ 的递推式。因为从上向下递归很麻烦,所以从下向上递推。从 $x^1$ 开始推至 $x^{2^z}(2^z>=n)$ 即可。初值是 $B=\{A(0)^{-1}\}$ 。 利用 NTT 可以将多项式乘法优化至$n\log n$,利用主定理计算得总时间复杂度为$O(n\log n)$。 代码: ```cpp #include<bits/stdc++.h> #define For(i,a,b) for(i=(a);i<=(b);++i) #define Forward(i,a,b) for(i=(a);i>=(b);--i) #define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i) #define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i) using namespace std; template<typename T>inline void read(T &x){ T s=0,f=1;char k=getchar(); while(!isdigit(k)&&k^'-')k=getchar(); if(!isdigit(k)){f=-1;k=getchar();} while(isdigit(k)){s=s*10+(k^48);k=getchar();} x=s*f; } void file(void){ #ifndef ONLINE_JUDGE freopen("water.in","r",stdin); freopen("water.out","w",stdout); #endif } const int MAXN=1<<20; static int n,m,rev[MAXN],a[MAXN],b[2][MAXN]; inline void init() { read(n);--n; Rep(i,0,n)read(a[i]); m=n<<1; for(n=2;n<=m;n<<=1); } inline void calrev(int n,int len) { Rep(i,1,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<len); } const int mod=998244353,gen=3; inline int modu(long long x) { if(x<mod)return x; return x-x/mod*mod; } inline int power(int x,int y) { static int sum; for(sum=1;y;y>>=1,x=modu(1ll*x*x)) if(y&1)sum=modu(1ll*sum*x); return sum; } static int P[MAXN],iv[MAXN]; inline void predone(int n) { static int i,j; for(i=1,j=2;j<=n<<1;++i,j<<=1) { P[i]=power(gen,(mod-1)/j); iv[i]=power(P[i],mod-2); } } inline int modulo(int x,int y){x+=y;if(x>=mod)x-=mod;return x;} inline void NTT(int x[],int type) { Rep(i,1,n-1)if(i<rev[i])swap(x[i],x[rev[i]]); static int i,j,k,kk,t,w,wn,tk; for(i=2,tk=1;i<=n;i<<=1,++tk) { kk=i>>1; if(type==1)wn=P[tk]; else wn=iv[tk]; for(j=0;j<n;j+=i) { w=1; for(k=0;k<kk;++k,w=modu(1ll*w*wn)) { t=modu(1ll*w*x[j+k+kk]); x[j+k+kk]=module(x[j+k],mod-t); x[j+k]=module(x[j+k],t); } } } if(type==-1) { int inv=power(n,mod-2); Rep(i,0,n)x[i]=modu(1ll*x[i]*inv); } } static int X[MAXN],Y[MAXN]; inline void mul(int x[],int y[]) { memset(X,0,sizeof X); memset(Y,0,sizeof Y); Rep(i,0,n>>1)X[i]=x[i],Y[i]=y[i]; NTT(X,1);NTT(Y,1); Rep(i,0,n)X[i]=modu(1ll*X[i]*Y[i]); NTT(X,-1); Rep(i,0,n)x[i]=X[i]; } static int c[MAXN]; inline void test(int z) { memset(c,0,sizeof c); Rep(i,0,n)c[i]=a[i]; mul(c,b[z]);Rep(i,0,m>>1)cout<<c[i]<<' ';puts(""); } inline void solve() { static int t=0,bas=1,len=1; b[0][0]=power(a[0],mod-2); n=4; calrev(n,len); while(bas<m) { t^=1; memset(b[t],0,sizeof b[t]); Rep(i,0,bas)b[t][i]=module(b[t^1][i]<<1,0); mul(b[t^1],b[t^1]); mul(b[t^1],a); Rep(i,0,bas)b[t][i]=module(b[t][i],mod-b[t^1][i]); bas<<=1;n<<=1;++len; if(bas<m)calrev(n,len); } Rep(i,0,m>>1)printf("%d ",b[t][i]); puts(""); } int main(void){ file(); init(); predone(n); solve(); //cerr<<1.0*clock()/CLOCKS_PER_SEC<<endl; return 0; } ```