题解 P4238 【【模板】多项式求逆】
Great_Influence
2018-02-25 14:58:50
递归求解。
求模 $x^n$ 的逆元时,假设先求出了模 $x^{\lceil\frac{
n}{2}\rceil}$ 的逆元。
设模 $x^n$ 逆元为 $B$ ,模 $x^{\lceil\frac{n}{2}\rceil}$ 逆元为 $B'$ ,则:
$$A*B'\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$
且
$$A*B\equiv 1\pmod {x^{\lceil\frac{z}{2}\rceil}}$$
$$\therefore B'-B\equiv 0\pmod {x^{\lceil\frac{z}{2}\rceil}}$$
两边平方,则
$$(B'-B)^2\equiv 0\pmod {x^z}$$
拆项,
$$B'^2-2BB'+B^2\equiv 0\pmod {x^z}$$
左右同乘$A$,
$$AB'^2-2B'+B\equiv 0\pmod {x^z}$$
移项,得
$$B\equiv 2B'-AB'^2\pmod {x^z}$$
然后就得到了关于 $B$ 的递推式。因为从上向下递归很麻烦,所以从下向上递推。从 $x^1$ 开始推至 $x^{2^z}(2^z>=n)$ 即可。初值是 $B=\{A(0)^{-1}\}$ 。
利用 NTT 可以将多项式乘法优化至$n\log n$,利用主定理计算得总时间复杂度为$O(n\log n)$。
代码:
```cpp
#include<bits/stdc++.h>
#define For(i,a,b) for(i=(a);i<=(b);++i)
#define Forward(i,a,b) for(i=(a);i>=(b);--i)
#define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i)
#define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i)
using namespace std;
template<typename T>inline void read(T &x){
T s=0,f=1;char k=getchar();
while(!isdigit(k)&&k^'-')k=getchar();
if(!isdigit(k)){f=-1;k=getchar();}
while(isdigit(k)){s=s*10+(k^48);k=getchar();}
x=s*f;
}
void file(void){
#ifndef ONLINE_JUDGE
freopen("water.in","r",stdin);
freopen("water.out","w",stdout);
#endif
}
const int MAXN=1<<20;
static int n,m,rev[MAXN],a[MAXN],b[2][MAXN];
inline void init()
{
read(n);--n;
Rep(i,0,n)read(a[i]);
m=n<<1;
for(n=2;n<=m;n<<=1);
}
inline void calrev(int n,int len)
{
Rep(i,1,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<len);
}
const int mod=998244353,gen=3;
inline int modu(long long x)
{
if(x<mod)return x;
return x-x/mod*mod;
}
inline int power(int x,int y)
{
static int sum;
for(sum=1;y;y>>=1,x=modu(1ll*x*x))
if(y&1)sum=modu(1ll*sum*x);
return sum;
}
static int P[MAXN],iv[MAXN];
inline void predone(int n)
{
static int i,j;
for(i=1,j=2;j<=n<<1;++i,j<<=1)
{
P[i]=power(gen,(mod-1)/j);
iv[i]=power(P[i],mod-2);
}
}
inline int modulo(int x,int y){x+=y;if(x>=mod)x-=mod;return x;}
inline void NTT(int x[],int type)
{
Rep(i,1,n-1)if(i<rev[i])swap(x[i],x[rev[i]]);
static int i,j,k,kk,t,w,wn,tk;
for(i=2,tk=1;i<=n;i<<=1,++tk)
{
kk=i>>1;
if(type==1)wn=P[tk];
else wn=iv[tk];
for(j=0;j<n;j+=i)
{
w=1;
for(k=0;k<kk;++k,w=modu(1ll*w*wn))
{
t=modu(1ll*w*x[j+k+kk]);
x[j+k+kk]=module(x[j+k],mod-t);
x[j+k]=module(x[j+k],t);
}
}
}
if(type==-1)
{
int inv=power(n,mod-2);
Rep(i,0,n)x[i]=modu(1ll*x[i]*inv);
}
}
static int X[MAXN],Y[MAXN];
inline void mul(int x[],int y[])
{
memset(X,0,sizeof X);
memset(Y,0,sizeof Y);
Rep(i,0,n>>1)X[i]=x[i],Y[i]=y[i];
NTT(X,1);NTT(Y,1);
Rep(i,0,n)X[i]=modu(1ll*X[i]*Y[i]);
NTT(X,-1);
Rep(i,0,n)x[i]=X[i];
}
static int c[MAXN];
inline void test(int z)
{
memset(c,0,sizeof c);
Rep(i,0,n)c[i]=a[i];
mul(c,b[z]);Rep(i,0,m>>1)cout<<c[i]<<' ';puts("");
}
inline void solve()
{
static int t=0,bas=1,len=1;
b[0][0]=power(a[0],mod-2);
n=4;
calrev(n,len);
while(bas<m)
{
t^=1;
memset(b[t],0,sizeof b[t]);
Rep(i,0,bas)b[t][i]=module(b[t^1][i]<<1,0);
mul(b[t^1],b[t^1]);
mul(b[t^1],a);
Rep(i,0,bas)b[t][i]=module(b[t][i],mod-b[t^1][i]);
bas<<=1;n<<=1;++len;
if(bas<m)calrev(n,len);
}
Rep(i,0,m>>1)printf("%d ",b[t][i]);
puts("");
}
int main(void){
file();
init();
predone(n);
solve();
//cerr<<1.0*clock()/CLOCKS_PER_SEC<<endl;
return 0;
}
```