题解:P10249 【模板】多项式复合函数(加强版)

· · 题解

多项式复合的 O(n\log^2n) Bostan–Mori 算法

给出一种不需要转置原理的解释,但是为了导出整个递归算法的不变量,可能需要先理解这个算法。

给出 f(x)=\sum_{k=0}^{n-1}f_kx^k \in\mathbb{C}\left\lbrack x\right\rbrackg(x)\in\mathbb{C}\left\lbrack x\right\rbrack 那么

f(g)=\sum_{k\geq 0}f_kg^k

我们考虑多项式复合 f(g)\bmod{x^n},因为本题中如果需要计算形式幂级数的复合,那么会要求 g(0)=0

考虑

\frac{f(y^{-1})}{1-yg(x)}=\sum_{k\geq 0}(\cdots +f_ky^{-k}+\cdots)y^kg(x)^k \in\mathbb{C}\left\lbrack x\right\rbrack \left(\left( y\right)\right)

那么我们的目标就是求算

\left\lbrack y^0\right\rbrack\frac{f(y^{-1})}{1-yg(x)}\bmod{x^n}

根据 Bostan–Mori 算法有

\begin{aligned} \frac{P(y)}{Q(x,y)}&=\frac{P(y)}{Q(x,y)Q(-x,y)}Q(-x,y) \\ &=\frac{P(y)}{V(x^2,y)}Q(-x,y) \end{aligned}

我们先考虑计算 \dfrac{P(y)}{V(x^2,y)}\bmod{x^n},为了能够计算这一部分,我们需要先进行一次多项式乘法计算出 V(x^2,y)\bmod{x^n} 然后设置子问题为 \dfrac{P(y)}{V(x,y)}\bmod{x^{\left\lfloor n/2\right\rfloor}},注意到 \deg_y(V)=2\deg_y(Q)x 的次数只需要一半所以子问题的规模还是基本相同的,而在 \dfrac{P(y)}{V(x,y)}\bmod{x} 时我们可以解决该问题,因为此时 V(x,y)\bmod{x} 实际上是一个(常数项为 1 的)一元多项式,使用形式幂级数的乘法逆元算法即可求算 P/V

剩下的问题是递归算法该返回什么,考虑我们求出 \dfrac{P(y)}{V(x^2,y)}\bmod{x^n} 后需要将其乘以 Q(-x,y) 那么 y^{>0} 的部分不会对我们最后要求的 y^0 的系数产生影响,所以子问题应该返回 y^{-2\deg_y(Q)+1},\dots ,y^0 的系数,当前问题返回 y^{-\deg_y(Q)+1},\dots ,y^0 的系数因为子问题的 y^{\leq -2\deg_y(Q)} 的系数不会对当前问题的 y^{\geq -\deg_y(Q)+1} 的系数产生影响,而当前问题作为上一次递归的子问题也是同样的,如此我们可以构建出递归的不变量然后给出伪代码:

\begin{array}{ll} &\textbf{Algorithm }\operatorname{Composition-Subprocedure}(P,Q,n)\text{:} \\ &\textbf{Input}\text{: }P=\sum_{0\leq j\leq n}p_jy^{-j}\in\mathbb{C}((y)),Q\in\mathbb{C}\left\lbrack\left\lbrack x\right\rbrack\right\rbrack\left\lbrack y\right\rbrack\text{.} \\ &\textbf{Output}\text{: }\left\lbrack y^{\left\lbrack -\deg_y(Q)+1,0\right\rbrack}\right\rbrack\dfrac{P}{Q}\bmod{x^{n+1}}\text{.} \\ 1&d\gets \deg_y(Q)\\ 2&\textbf{if }n=0\textbf{ then return }\left(\left\lbrack y^{-d+1}\right\rbrack P/Q,\dots ,\left\lbrack y^0\right\rbrack P/Q\right) \\ 3&V(x^2,y)\gets Q(x,y)Q(-x,y)\bmod{x^{n+1}} \\ 4&(t_{-2d+1},\dots ,t_0)\gets \operatorname{Composition-Subprocedure}\left(P,V(x,y),\left\lfloor n/2\right\rfloor\right) \\ 5&T(x,y)\gets \sum_{j=-2d+1}^0t_jy^j \\ 6&U(x,y)=\sum_{j=-2d+1}^d u_jy^j\gets T(x^2,y)Q(-x,y)\bmod{x^{n+1}} \\ 7&\textbf{return }\left(u_{-d+1},\dots ,u_0\right) \end{array}

那么复合的代码就是

\begin{array}{ll} &\textbf{Algorithm }\operatorname{Composition}(f,g,n)\text{:} \\ &\textbf{Input}\text{: }f,g\in\mathbb{C}\left\lbrack x\right\rbrack ,n\in\mathbb{Z}_{>0}\text{.} \\ &\textbf{Output}\text{: }f(g)\bmod{x^n}\text{.} \\ 1&(u)\gets\operatorname{Composition-Subprocedure}\left(f\left(y^{-1}\right),1-yg(x),n-1\right) \\ 2&\textbf{return }(u) \end{array}

注意一个隐藏的约束就是递归终止时的 \deg_y(Q) 需要大于递归开始时的 n,否则的话我们需要的信息就没有办法正确保留,这是容易验证的。

补充

对于递归结束的计算 P/Q 真的需要计算形式幂级数的乘法逆元吗?noshi91 在 X(Twitter)上告诉我考虑

\begin{aligned} \left\lbrack x^0\right\rbrack Q&=\left(1-y\left\lbrack x^0\right\rbrack g\right)^{\deg_y(Q)}\\ &=\sum_{j=0}^{\deg_y(Q)}\binom{\deg_y(Q)}{j}\left(-\left\lbrack x^0\right\rbrack g\right)^{j}y^j \end{aligned}

那么

\left\lbrack x^0\right\rbrack Q^{-1}=\sum_{j\geq 0}\binom{\deg_y(Q)+j-1}{j}\left(\left\lbrack x^0\right\rbrack g\right)^{j}y^j

所以最后我们只需一次多项式乘法。Thanks noshi91!

应用

我们可以发现下面的形式幂级数运算都可以通过复合解决:

实现

我们在 C++17 标准下实现,注意 ntt 之后的结果是“位逆序”的,而 intt 的参数需要“位逆序”的,这样可以很方便的应用我们的优化。

#include <algorithm>
#include <cassert>
#include <iostream>
#include <type_traits>
#include <vector>

template <unsigned Mod>
class ModInt {
    static_assert((Mod >> 31) == 0, "`Mod` must less than 2^(31)");
    static unsigned safe_mod(int v) {
        if ((v %= (int)Mod) < 0) v += (int)Mod;
        return v;
    }

    struct private_constructor_t {};
    static inline private_constructor_t private_constructor{};
    ModInt(private_constructor_t, unsigned v) : v_(v) {}

    unsigned v_;

public:
    static unsigned mod() { return Mod; }
    static ModInt from_raw(unsigned v) { return ModInt(private_constructor, v); }
    ModInt() : v_() {}
    ModInt(int v) : v_(safe_mod(v)) {}
    unsigned val() const { return v_; }

    ModInt operator-() const { return from_raw(v_ == 0 ? v_ : Mod - v_); }
    ModInt pow(int e) const {
        if (e < 0) return inv().pow(-e);
        for (ModInt x(*this), res(from_raw(1));; x *= x) {
            if (e & 1) res *= x;
            if ((e >>= 1) == 0) return res;
        }
    }
    ModInt inv() const {
        int x1 = 1, x3 = 0, a = v_, b = Mod;
        while (b) {
            int q = a / b, x1_old = x1, a_old = a;
            x1 = x3, x3 = x1_old - x3 * q, a = b, b = a_old - b * q;
        }
        return from_raw(x1 < 0 ? x1 + (int)Mod : x1);
    }
    std::enable_if_t<(Mod & 1), ModInt> div_by_2() const {
        if ((v_ & 1) == 0) return from_raw(v_ >> 1);
        return from_raw((v_ + Mod) >> 1);
    }

    ModInt &operator+=(const ModInt &a) {
        if ((v_ += a.v_) >= Mod) v_ -= Mod;
        return *this;
    }
    ModInt &operator-=(const ModInt &a) {
        if ((v_ += Mod - a.v_) >= Mod) v_ -= Mod;
        return *this;
    }
    ModInt &operator*=(const ModInt &a) {
        v_ = (unsigned long long)v_ * a.v_ % Mod;
        return *this;
    }
    ModInt &operator/=(const ModInt &a) { return *this *= a.inv(); }

    friend ModInt operator+(const ModInt &a, const ModInt &b) { return ModInt(a) += b; }
    friend ModInt operator-(const ModInt &a, const ModInt &b) { return ModInt(a) -= b; }
    friend ModInt operator*(const ModInt &a, const ModInt &b) { return ModInt(a) *= b; }
    friend ModInt operator/(const ModInt &a, const ModInt &b) { return ModInt(a) /= b; }
    friend bool operator==(const ModInt &a, const ModInt &b) { return a.v_ == b.v_; }
    friend bool operator!=(const ModInt &a, const ModInt &b) { return a.v_ != b.v_; }
    friend std::istream &operator>>(std::istream &a, ModInt &b) {
        int v;
        a >> v;
        b.v_ = safe_mod(v);
        return a;
    }
    friend std::ostream &operator<<(std::ostream &a, const ModInt &b) { return a << b.v_; }
};

template <typename Tp>
class NttInfo {
    static Tp least_quadratic_nonresidue() {
        for (int i = 2;; ++i)
            if (Tp::from_raw(i).pow((Tp::mod() - 1) / 2) == -1) return Tp::from_raw(i);
    }

    const int ordlog2_;
    const Tp zeta_;
    const Tp invzeta_;

    mutable std::vector<Tp> root_;
    mutable std::vector<Tp> invroot_;

    NttInfo()
        : ordlog2_(__builtin_ctz(Tp::mod() - 1)),
          zeta_(least_quadratic_nonresidue().pow((Tp::mod() - 1) >> ordlog2_)),
          invzeta_(zeta_.inv()), root_{Tp::from_raw(1)}, invroot_{Tp::from_raw(1)} {}

public:
    static const NttInfo &get() {
        static NttInfo info;
        return info;
    }

    Tp zeta() const { return zeta_; }
    Tp inv_zeta() const { return invzeta_; }
    const std::vector<Tp> &root(int n) const {
        // 预处理 [0, n)
        assert((n & (n - 1)) == 0);
        if (const int s = root_.size(); s < n) {
            root_.resize(n);
            for (int i = __builtin_ctz(s); (1 << i) < n; ++i) {
                const int j = 1 << i;
                root_[j]    = zeta_.pow(1 << (ordlog2_ - i - 2));
                for (int k = j + 1; k < j * 2; ++k) root_[k] = root_[k - j] * root_[j];
            }
        }
        return root_;
    }
    const std::vector<Tp> &inv_root(int n) const {
        // 预处理 [0, n)
        assert((n & (n - 1)) == 0);
        if (const int s = invroot_.size(); s < n) {
            invroot_.resize(n);
            for (int i = __builtin_ctz(s); (1 << i) < n; ++i) {
                const int j = 1 << i;
                invroot_[j] = invzeta_.pow(1 << (ordlog2_ - i - 2));
                for (int k = j + 1; k < j * 2; ++k) invroot_[k] = invroot_[k - j] * invroot_[j];
            }
        }
        return invroot_;
    }
};

template <typename Tp>
class Binomial {
    std::vector<Tp> factorial_, invfactorial_;

    Binomial() : factorial_{Tp::from_raw(1)}, invfactorial_{Tp::from_raw(1)} {}

    // 预处理 [0, n)
    void preprocess(int n) {
        if (const int nn = factorial_.size(); nn < n) {
            int k = nn;
            while (k < n) k *= 2;
            factorial_.resize(k);
            invfactorial_.resize(k);
            for (int i = nn; i != k; ++i) factorial_[i] = factorial_[i - 1] * Tp::from_raw(i);
            invfactorial_.back() = factorial_.back().inv();
            for (int i = k - 2; i >= nn; --i)
                invfactorial_[i] = invfactorial_[i + 1] * Tp::from_raw(i + 1);
        }
    }

public:
    static const Binomial &get(int n) {
        static Binomial bin;
        bin.preprocess(n);
        return bin;
    }

    Tp binom(int n, int m) const {
        return n < m ? Tp() : factorial_[n] * invfactorial_[m] * invfactorial_[n - m];
    }
    Tp inv(int n) const { return factorial_[n - 1] * invfactorial_[n]; }
    Tp factorial(int n) const { return factorial_[n]; }
    Tp inv_factorial(int n) const { return invfactorial_[n]; }
};

int ntt_len(int n) {
    --n;
    n |= n >> 1, n |= n >> 2, n |= n >> 4, n |= n >> 8;
    return (n | n >> 16) + 1;
}

template <typename Tp>
void ntt(std::vector<Tp> &a) {
    const int n = a.size();
    assert((n & (n - 1)) == 0);
    for (int j = 0, l = n >> 1; j != l; ++j) {
        auto u = a[j], v = a[j + l];
        a[j] = u + v, a[j + l] = u - v;
    }
    auto &&root = NttInfo<Tp>::get().root(n / 2);
    for (int i = n >> 1; i >= 2; i >>= 1) {
        for (int j = 0, l = i >> 1; j != l; ++j) {
            auto u = a[j], v = a[j + l];
            a[j] = u + v, a[j + l] = u - v;
        }
        for (int j = i, l = i >> 1, m = 1; j != n; j += i, ++m)
            for (int k = j; k != j + l; ++k) {
                auto u = a[k], v = a[k + l] * root[m];
                a[k] = u + v, a[k + l] = u - v;
            }
    }
}

template <typename Tp>
void intt(std::vector<Tp> &a) {
    const int n = a.size();
    assert((n & (n - 1)) == 0);
    auto &&root = NttInfo<Tp>::get().inv_root(n / 2);
    for (int i = 2; i < n; i <<= 1) {
        for (int j = 0, l = i >> 1; j != l; ++j) {
            auto u = a[j], v = a[j + l];
            a[j] = u + v, a[j + l] = u - v;
        }
        for (int j = i, l = i >> 1, m = 1; j != n; j += i, ++m)
            for (int k = j; k != j + l; ++k) {
                auto u = a[k], v = a[k + l];
                a[k] = u + v, a[k + l] = (u - v) * root[m];
            }
    }
    const auto iv = Tp::from_raw(Tp::mod() - Tp::mod() / n);
    for (int j = 0, l = n >> 1; j != l; ++j) {
        auto u = a[j] * iv, v = a[j + l] * iv;
        a[j] = u + v, a[j + l] = u - v;
    }
}

template <typename Tp>
std::vector<Tp> convolution_ntt(std::vector<Tp> a, std::vector<Tp> b) {
    const int n   = a.size();
    const int m   = b.size();
    const int len = ntt_len(n + m - 1);

    a.resize(len);
    b.resize(len);
    ntt(a);
    ntt(b);
    for (int i = 0; i != len; ++i) a[i] *= b[i];
    intt(a);

    a.resize(n + m - 1);
    return a;
}

// returns f(g) mod x^n
template <typename Tp>
std::vector<Tp> composition(const std::vector<Tp> &f, const std::vector<Tp> &g, int n) {
    if (n <= 0) return {};
    if (g.empty()) return std::vector<Tp>(n);

    struct composition_rec {
        // f_[i] 是 y^(-i) 的系数
        composition_rec(const std::vector<Tp> &f, Tp g0) : f_(f), g0_(g0) {}
        // returns [y^0]P/Q mod x^(n+1)
        // 返回数组 [0,n] 放 y^(-d+1) 的系数以此类推,这样我们只需要一维数组
        std::vector<Tp> run(const std::vector<Tp> &Q, int d, int n) const {
            // Q 的 [0,n] 放 y^(-d+1) 的系数,以此类推
            if (n == 0) {
                assert(d >= f_.size()); // 需要保留全部系数
                // P[i] 是 y^(-d+1+i) 的系数
                std::vector<Tp> P(d), invQ(d);
                for (int i = d - 1, j = 0; j < (int)f_.size() && i >= 0;) P[i--] = f_[j++];
                auto &&bin = Binomial<Tp>::get(d * 2);
                // 利用二项式系数直接算出 Q^(-1) 的系数,不需要求逆
                for (int i = 0; i < d; ++i) invQ[i] = bin.binom(d + i - 1, i) * g0_.pow(i);
                auto PinvQ = convolution_ntt(P, invQ);
                PinvQ.resize(d);
                return PinvQ;
            }

            // deg_y(Q)=d => deg_y(Q(x,y)Q(-x,y))=2d, 系数有 2d+1 项
            // deg_x(Q)=n => deg_x(Q(x,y)Q(-x,y))=2n,系数有 2n+1 项
            // 令 y=x^(2n+2) 作一元卷积,仍然可以分离出所需要的系数
            // 并且这样 y 的位置就会在下标为偶数的位置
            // Q(-x,y) 的系数就不需要再计算了,因为 ntt 计算的点值是成对的,例如计算
            // ntt(f) 那么算了 f(1), f(-1), ..., 并且我们的 ntt 是为位逆序的,所以是相邻的
            const int len = ntt_len((d * 2 + 1) * (n * 2 + 2) - 1);
            // d 最小为 1,n 最小为 1 此时 len 最小为 16
            std::vector<Tp> nttQ(len);
            for (int i = 0; i <= d; ++i)
                for (int j = 0; j <= n; ++j) nttQ[i * (n * 2 + 2) + j] = Q[i * (n + 1) + j];
            ntt(nttQ);
            // 取 x^(2k) 系数只需要做一半长度的 intt 见 Bostan--Mori 的论文
            std::vector<Tp> VV(len / 2);
            for (int i = 0; i != len; i += 2) VV[i / 2] = nttQ[i] * nttQ[i + 1];
            intt(VV);
            std::vector<Tp> V((d * 2 + 1) * (n / 2 + 1));
            for (int i = 0; i <= d * 2; ++i)
                for (int j = 0; j <= n / 2; ++j) V[i * (n / 2 + 1) + j] = VV[i * (n + 1) + j];

            const auto T = run(V, d * 2, n / 2);

            // T 存了 y^(-2d+1) 到 y^0 的系数
            std::vector<Tp> nttT(len / 2);
            for (int i = 0; i < d * 2; ++i)
                for (int j = 0; j <= n / 2; ++j) nttT[i * (n + 1) + j] = T[i * (n / 2 + 1) + j];
            // 同上原因,我们只需要做一半长度
            ntt(nttT);
            std::vector<Tp> UU(len);
            for (int i = 0; i != len; i += 2) {
                UU[i]     = nttT[i / 2] * nttQ[i + 1];
                UU[i + 1] = nttT[i / 2] * nttQ[i];
            }
            intt(UU);
            std::vector<Tp> U(d * (n + 1));
            // 提取循环卷积 y^(-d+1) 到 y^0 的系数
            for (int i = 0; i < d; ++i)
                for (int j = 0; j <= n; ++j) U[i * (n + 1) + j] = UU[(i + d) * (n * 2 + 2) + j];
            return U;
        }

    private:
        const std::vector<Tp> &f_;
        const Tp g0_;
    } a(f, g[0]);

    std::vector<Tp> Q(n * 2); // [0,n)=1, [n,2n)=-g
    Q[0] = Tp::from_raw(1);
    for (int i = n, j = 0; j < (int)g.size() && i < n * 2;) Q[i++] = -g[j++];
    // 在此处 f 和 g 是多项式,因为 g(0) 不为零所以 f 不能在 mod x^n 意义下计算,需要全部系数
    auto res = a.run(Q, 1, std::max(n - 1, (int)f.size() - 1));
    res.resize(n);
    return res;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    using mint = ModInt<998244353>;

    int n, m;
    std::cin >> n >> m;

    std::vector<mint> f(n + 1), g(m + 1);
    for (int i = 0; i <= n; ++i) std::cin >> f[i];
    for (int i = 0; i <= m; ++i) std::cin >> g[i];

    auto fg = composition(f, g, n + 1);
    for (int i = 0; i <= n; ++i) std::cout << fg[i] << ' ';
    return 0;
}

参考文献