多项式exp--题解 P4726 【【模板】多项式指数函数(多项式 exp)】

· · 题解

有一个结论:如果将多项式的方程进行牛顿迭代,每次的逼近位数会翻倍。换而言之: 如果方程形如 $$F(G(x))\equiv 0\pmod{x^n}$$ 而已经得知 $$F(G_0(x))\equiv 0\pmod{x^{\lceil \frac n2\rceil}}$$ 则经过一次牛顿迭代后可推知上式。 可以考虑通过如下方式证明。 将 $F(G(x))$ 在 $G_0(x)$ 处泰勒展开: $$F(G(x))=\sum_{n=0}^{\infty}\frac{F^{(n)}(G_0(x))}{n!}(G(x)-G_0(x))^n\equiv 0\pmod{x^n}$$ 根据 $G_0\equiv G\pmod{x^{\lceil\frac n2\rceil}}$,得到 $G_0^2\equiv G^2\pmod{x^n}$(迭代后已求的系数不会改变) 又有 $(a-b)^n\equiv a^n-b^n\pmod n$[参考文章](https://www.luogu.com.cn/blog/void-basic-learner/polygon-square-root-enhanced),多项式即 $(F-G)^n\equiv F^n-G^n\pmod{x^n}$。 因此泰勒展开式第三项及以后的项都在 $\pmod{x^n}$ 下为 $0$。 $$ F(G_0(x))+F'(G_0(x))(G(x)-G_0(x))\equiv 0\pmod{x^n}$$ 得到迭代公式: $$G(x)\equiv G_0(x)-\frac{F(G_0(x))}{F'(G_0(x))}\pmod{x^n}$$ 则可以尝试递归地应用到本题。 --- 本题中求 $$G(x)\equiv e^{A(x)}\pmod{x^n}$$ 即 $$\ln G(x)-A(x)\equiv 0\pmod{x^n}$$ 就是这样一个方程。迭代的函数自然是 $$F(G(x))=\ln G(x) - A(x)$$ 考虑边界情况下, $n=1$。题目中给出了 $a_0=0$,则有 $\ln g_0-0=0$,常数项赋为 $1$ 即可。 否则运用迭代公式。其中有 $F'(G(x))=1/G(x)$: $$\begin{aligned} G(x)&\equiv G_0(x)-\frac{\ln G_0(x)-A(x)}{\frac 1 {G_0(x)}}\pmod{x^n} \\ &\equiv G_0(x)-G_0(\ln G_0(x)-A(x))\pmod{x^n} \\ &\equiv G_0(x)(1-\ln G_0(x)+A(x))\pmod{x^n} \end{aligned}$$ 于是每次递归中求一遍 $\ln$,一遍加减,再一遍乘法,一共 $O(n\log n)$,总共也是 $O(n\log n)$ 的。 ```cpp #include <cstdio> #include <iostream> #include <cstring> #include <cmath> #include <algorithm> typedef long long ll; const int MAXN = 400001; const int p = 998244353; inline int read() { int x = 0,f = 1; char ch = getchar(); while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); } do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9'); return x * f; } ll fastpow(ll a,int b) { ll res = 1; a %= p; while(b) { if(b & 1) res = res * a % p; a = a * a % p; b >>= 1; } return res; } int n,r[MAXN]; ll a[MAXN],b[MAXN],f[MAXN],g[MAXN]; void NTT(ll *a,int N) { for(int i = 0;i < N;i++) if(i < r[i]) std::swap(a[i],a[r[i]]); for(int n = 2, m = 1;n <= N;m = n, n <<= 1) { ll g1 = fastpow(3,(p - 1) / n),t1,t2; for(int l = 0;l < N;l += n) { ll g = 1; for(int i = l;i < l + m;i++) { t1 = a[i], t2 = g * a[i + m] % p; a[i] = (t1 + t2) % p; a[i + m] = (t1 + p - t2) % p; g = g * g1 % p; } } } return; } void INTT(ll *a,int N) { NTT(a,N); std::reverse(a + 1,a + N); ll invN = fastpow(N,p - 2); for(int i = 0;i < N;i++) a[i] = a[i] * invN % p; return; } void Dervt(ll *a,ll *b,int n) { for(int i = 0;i < n;i++) b[i] = a[i + 1] * (i + 1) % p; b[n - 1] = 0; return; } void Integ(ll *a,ll *b,int n) { for(int i = 0;i < n;i++) b[i + 1] = a[i] * fastpow(i + 1,p - 2) % p; b[0] = 0; return; } ll a1[MAXN]; void Inv(ll *a,ll *b,int n) { if(n == 1) return void(b[0] = fastpow(a[0],p - 2)); Inv(a,b,(n + 1) >> 1); int N = 1, l = -1; while(N <= n << 1) N <<= 1, l++; for(int i = 1;i < N;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l); for(int i = 0;i < n;i++) a1[i] = a[i]; for(int i = n;i < N;i++) a1[i] = 0; NTT(a1,N); NTT(b,N); for(int i = 0;i < N;i++) b[i] = ((b[i] << 1) % p + p - a1[i] * b[i] % p * b[i] % p) % p; INTT(b,N); for(int i = n;i < N;i++) b[i] = 0; return; } void Ln(ll *a,ll *b,int n) { memset(g,0,sizeof(g)); Dervt(a,f,n); Inv(a,g,n); int N = 1, l = -1; while(N <= n << 1) N <<= 1, l++; for(int i = 1;i < N;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l); NTT(f,N); NTT(g,N); for(int i = 0;i < N;i++) f[i] = f[i] * g[i] % p; INTT(f,N); Integ(f,b,n); return; } ll lnb[MAXN]; void Exp(ll *a,ll *b,int n) { if(n == 1) return void(b[0] = 1); Exp(a,b,(n + 1) >> 1); Ln(b,lnb,n); int N = 1, l = -1; while(N <= n << 1) N <<= 1, l++; for(int i = 1;i < N;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l); for(int i = 0;i < n;i++) a1[i] = a[i]; for(int i = n;i < N;i++) lnb[i] = a1[i] = 0; for(int i = 0;i < N;i++) a1[i] = ((a1[i] - lnb[i]) % p + p) % p; a1[0]++; NTT(b,N); NTT(a1,N); for(int i = 0;i < N;i++) b[i] = b[i] * a1[i] % p; INTT(b,N); for(int i = n;i < N;i++) b[i] = 0; return; } int main() { n = read(); for(int i = 0;i < n;i++) a[i] = read(); Exp(a,b,n); for(int i = 0;i < n;i++) std::printf("%lld ",b[i]); return 0; } ```