P4238 【模板】多项式乘法逆 题解

· · 题解

多项式求逆

Preface

本篇与其他的题解推导略有不同,主要使用了牛顿迭代法,虽然最后的式子都是一样的,但是该方法泛用性更加广泛一些。

Description

已知 F(x),要求出 G(x) * F(x) \equiv 1 \pmod{x^n},系数对 998244353 取模。

Solution

对这个式子进行推导,我们将一个多项式变成一个简单的未知数,然后思考这个式子。

移项

F(x) - \frac{1}{G(x)} \equiv 0 \pmod{x^n}

这个东西可以用牛顿迭代公式计算。(OIer 们习惯性去直接背诵这个公式)

接下来是牛顿迭代的推导,不想看的同学可以直接快进到结论。

已知一个复合多项式

G(F(x)) \equiv 0 \pmod{x^n}

假设已知 F_0(x)\equiv F(x) \pmod{x^{\lceil\frac{n}{2}\rceil}},求 F_1(x) \equiv F(x) \pmod{x^n}

根据泰勒公式,在 F_0(x) 处展开 G(F(x))

G(F(x)) = \sum_{i=0}^{\infty}\frac{G^{(i)}(F_0(x))}{i!}(F(x)-F_0(x))^i

可以发现, F(x)-F_0(x) \equiv 0 \pmod{x^{\lceil\frac{n}{2}\rceil}},所以 (F(x)-F_0(x))^k \equiv 0 \pmod{x^n}k \ge 2 的时候成立。

因此:

G(F_1(x)) \equiv G(F_0(x)) + G'(F_0(x))(F_1(x) - F_0(x)) \pmod{x^n}

G(F_1(x)) \equiv 0 \pmod{x^n},整理得到:

F_1(x)=F_0(x)-\frac{G(F_0(x))}{G'(F_0(x))}

许多人熟知的牛顿迭代公式就是这个玩意了,然后这个东西的应用也非常简单。

以这个题目为例:

H(t)=F(x)-\frac{1}{t},那么就可以得到:

H(G(x)) \equiv 0 \pmod{x^n}

然后套入牛顿迭代公式:

G_1(x)=G_0(x)-\frac{H(G_0(x))}{H'(G_0(x))}

H(t) = F(x) - \frac{1}{t} 套回去,注意这里的 F(x) 表示的是一个常数,那么就可以得到:

G_1(x)=G_0(x) - \frac{F(x)-\frac{1}{G_0(x)}}{\frac{1}{G_0(x)^2}}

化简一下可以得到:

G_1(x)=G_0(x)(2-G_0(x)F(x))

打个 NTT 求解一下就可以了。注意边界情况!

code

#include <bits/stdc++.h>

using namespace std;

#define int long long

#define FOR(i,a,b) for(int i=(a),i##i=(b);i<=i##i;i++)
#define ROF(i,a,b) for(int i=(a),i##i=(b);i>=i##i;i--)

#define gc (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
char buf[100000], *p1(buf), *p2(buf);
#define rd read()
inline int read() {
    int x = 0, f = 1;
    char ch = gc;
    while(!isdigit(ch)) {
        if(ch == '-') f = 0;
        ch = gc;
    }
    while(isdigit(ch)) {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = gc;
    }
    return f ? x : -x;
}

const int N = 3e5 + 1;
const int mod = 998244353, g = 3;

inline int ksm(int x, int y) {
    int ret = 1;
    for(; y; y >>= 1, x = x * x % mod)
        if(y & 1) ret = ret * x % mod;
    return ret;
}

int rev[N];

inline void NTT(int *a, int n, int typ) {
    for(int i = 0; i < n; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 1; i < n; i <<= 1) {
        int gn = ksm(g, (mod - 1) / (i << 1));
        for(int j = 0, g0 = 1, x, y; j < n; j += (i << 1), g0 = 1) 
        for(int k = 0; k < i; k++, g0 = gn * g0 % mod) {
            x = a[j + k], y = a[i + j + k] * g0 % mod;
            a[j + k] = (x + y) % mod;
            a[i + j + k] = (x - y + mod) % mod;
        }
    }
    if(typ == 1) return ;
    int inv = ksm(n, mod - 2); reverse(a + 1, a + n);
    for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    return ;
}

void INV(int *b, int *a, int n) {
    if(n == 1) return b[0] = ksm(a[0], mod - 2), void();
    INV(b, a, (n + 1) >> 1);
    static int c[N];
    int len = 1, p = -1; while(len < (n << 1)) len <<= 1, p++;
    FOR(i, 1, len - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << p);
    copy(a, a + n, c), fill(c + n, c + len, 0);
    NTT(c, len, 1), NTT(b, len ,1);
    FOR(i, 0, len - 1) b[i] = (2 - b[i] * c[i] % mod + mod) % mod * b[i] % mod;
    NTT(b, len, 0), fill(b + n, b + len, 0);
    return ;
}

int a[N], b[N], n;

#undef int

inline void work() {
    n = rd;
    FOR(i, 0, n - 1) a[i] = rd;
    INV(b, a, n);
    FOR(i, 0, n - 1) cout << b[i] << ' ';
    return ;
}

int main() {
    work();
    return 0;
} 

参考文章