题解:P10249 【模板】多项式复合函数(加强版)
多项式复合的 $O(n\log^2 n)$ Bostan–Mori 算法
给出一种不需要转置原理的解释,但是为了导出整个递归算法的不变量,可能需要先理解这个算法。
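给出形式幂级数 $f(x)=\sum_{i\ge 0}f_ix^i$ 与 $g(x)=\sum_{i\ge 0}g_ix^i$。
我们考虑多项式复合 $f(g(x)) \bmod x^n$。
考虑 $P(y)=f(y^{-1})=\sum_{i\ge 0}f_iy^{-i}$ 与 $Q(x,y)=1-y\,g(x)$,由 $\dfrac{1}{Q(x,y)}=\sum_{k\ge 0}y^kg(x)^k$ 可知 $[y^0]\dfrac{P(y)}{Q(x,y)}=\sum_{i\ge 0}f_ig(x)^i=f(g(x))$。
那么我们的目标就是求算 $[y^0]\dfrac{P(y)}{Q(x,y)} \bmod x^n$。
根据 Bostan–Mori 算法有 $\dfrac{P(y)}{Q(x,y)}=\dfrac{P(y)\,Q(-x,y)}{Q(x,y)\,Q(-x,y)}=\dfrac{P(y)\,Q(-x,y)}{V(x^2,y)}$,其中 $V(x^2,y)=Q(x,y)\,Q(-x,y)$ 只含 $x$ 的偶数次项。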
我们先考虑计算
剩下的问题是递归算法该返回什么,考虑我们求出
那么复合的代码就是
注意一个隐藏的约束就是递归终止时的 $d$ 必须不小于 $f$ 的系数个数(即代码中的 `assert(d >= f_.size())`),否则 $f$ 的高次系数会被丢弃。
补充
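对于递归结束的计算 $\dfrac{P(y)}{Q(0,y)}$,此时 $Q(0,y)=(1-g_0y)^d$,其中 $g_0=g(0)$。
那么 $\dfrac{1}{(1-g_0y)^d}=\sum_{i\ge 0}\dbinom{d+i-1}{i}g_0^iy^i$,其系数可以由二项式系数直接写出,不需要多项式求逆。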
所以最后我们只需一次多项式乘法。Thanks noshi91!
应用
我们可以发现下面的形式幂级数运算都可以通过复合解决:
- $f(0)=1,\space \dfrac{1}{f}=1+(1-f)+(1-f)^2+\cdots$
- $f(0)=1,\space \log f=-\dfrac{1-f}{1}-\dfrac{(1-f)^2}{2}-\dfrac{(1-f)^3}{3}-\cdots$
- $f(0)=0,\space \exp f=1+\dfrac{f}{1!}+\dfrac{f^2}{2!}+\dfrac{f^3}{3!}+\cdots$
- $f(0)=1,\space f^e=1+\dfrac{e}{1}(f-1)+\dfrac{e(e-1)}{2}(f-1)^2+\cdots$
实现
我们在 C++17 标准下实现,注意 ntt 之后的结果是“位逆序”的,而 intt 的参数需要“位逆序”的,这样可以很方便地应用我们的优化。
#include <algorithm>
#include <cassert>
#include <iostream>
#include <type_traits>
#include <vector>
// Integer arithmetic modulo the compile-time constant `Mod` (0 < Mod < 2^31).
// Values built through the public constructors and operators are stored reduced
// to [0, Mod); from_raw() stores its argument unchanged.
template <unsigned Mod>
class ModInt {
    static_assert(Mod != 0, "`Mod` must be positive");
    static_assert((Mod >> 31) == 0, "`Mod` must less than 2^(31)");

    // Maps a signed value to its representative in [0, Mod).
    static unsigned safe_mod(int v) {
        if ((v %= (int)Mod) < 0) v += (int)Mod;
        return v;
    }

    // Tag selecting the non-reducing constructor used by from_raw().
    struct private_constructor_t {};
    static inline private_constructor_t private_constructor{};

    ModInt(private_constructor_t, unsigned v) : v_(v) {}

    unsigned v_;

public:
    static unsigned mod() { return Mod; }

    // Wraps `v` without reducing it; the caller is expected to pass a value in [0, Mod).
    static ModInt from_raw(unsigned v) { return ModInt(private_constructor, v); }

    ModInt() : v_() {}
    ModInt(int v) : v_(safe_mod(v)) {}

    unsigned val() const { return v_; }

    ModInt operator-() const { return from_raw(v_ == 0 ? v_ : Mod - v_); }

    // Returns (*this)^e by binary exponentiation. A negative exponent raises the
    // inverse to the power |e|.
    ModInt pow(int e) const {
        ModInt x = e < 0 ? inv() : *this;
        // Take |e| in unsigned arithmetic so that e == INT_MIN does not overflow.
        unsigned k = e < 0 ? 0u - static_cast<unsigned>(e) : static_cast<unsigned>(e);
        for (ModInt res(from_raw(1));; x *= x) {
            if (k & 1) res *= x;
            if ((k >>= 1) == 0) return res;
        }
    }

    // Modular inverse via the extended Euclidean algorithm.
    // The result is an inverse only when gcd(value, Mod) == 1.
    ModInt inv() const {
        int x1 = 1, x3 = 0, a = v_, b = Mod;
        while (b) {
            int q = a / b, x1_old = x1, a_old = a;
            x1 = x3, x3 = x1_old - x3 * q, a = b, b = a_old - b * q;
        }
        return from_raw(x1 < 0 ? x1 + (int)Mod : x1);
    }

    // Returns *this / 2. Requires an odd modulus. The check lives in the body so
    // that it only fires when this member is used; the class itself can still be
    // instantiated with an even `Mod`.
    ModInt div_by_2() const {
        static_assert((Mod & 1) != 0, "`div_by_2` requires an odd `Mod`");
        if ((v_ & 1) == 0) return from_raw(v_ >> 1);
        // v_ is odd and Mod is odd, so v_ + Mod is even and represents the same residue.
        return from_raw((v_ + Mod) >> 1);
    }

    ModInt &operator+=(const ModInt &a) {
        if ((v_ += a.v_) >= Mod) v_ -= Mod;
        return *this;
    }
    ModInt &operator-=(const ModInt &a) {
        // Adding Mod - a.v_ keeps the intermediate non-negative (and below 2^32).
        if ((v_ += Mod - a.v_) >= Mod) v_ -= Mod;
        return *this;
    }
    ModInt &operator*=(const ModInt &a) {
        v_ = (unsigned long long)v_ * a.v_ % Mod;
        return *this;
    }
    ModInt &operator/=(const ModInt &a) { return *this *= a.inv(); }

    friend ModInt operator+(const ModInt &a, const ModInt &b) { return ModInt(a) += b; }
    friend ModInt operator-(const ModInt &a, const ModInt &b) { return ModInt(a) -= b; }
    friend ModInt operator*(const ModInt &a, const ModInt &b) { return ModInt(a) *= b; }
    friend ModInt operator/(const ModInt &a, const ModInt &b) { return ModInt(a) /= b; }
    friend bool operator==(const ModInt &a, const ModInt &b) { return a.v_ == b.v_; }
    friend bool operator!=(const ModInt &a, const ModInt &b) { return a.v_ != b.v_; }

    // Reads a signed int and stores it reduced to [0, Mod).
    friend std::istream &operator>>(std::istream &a, ModInt &b) {
        int v;
        a >> v;
        b.v_ = safe_mod(v);
        return a;
    }
    friend std::ostream &operator<<(std::ostream &a, const ModInt &b) { return a << b.v_; }
};
// Per-field NTT constants: a generator of the 2-power part of the multiplicative
// group of Tp, plus lazily grown twiddle tables in the order used by ntt()/intt().
// A single instance per Tp is reachable through get().
template <typename Tp>
class NttInfo {
    // Smallest i >= 2 whose (mod-1)/2-th power equals -1.
    static Tp least_quadratic_nonresidue() {
        int cand = 2;
        while (!(Tp::from_raw(cand).pow((Tp::mod() - 1) / 2) == -1)) ++cand;
        return Tp::from_raw(cand);
    }

    // Grows `table` to `n` entries. The table size is always a power of two; each
    // new half [h, 2h) is seeded with a power of `unit` at index h and completed
    // as table[h + k] = table[k] * table[h].
    static void extend(std::vector<Tp> &table, const Tp &unit, int ordlog2, int n) {
        const int have = table.size();
        if (have >= n) return;
        table.resize(n);
        for (int lvl = __builtin_ctz(have); (1 << lvl) < n; ++lvl) {
            const int head = 1 << lvl;
            table[head] = unit.pow(1 << (ordlog2 - lvl - 2));
            for (int off = 1; off < head; ++off) table[head + off] = table[off] * table[head];
        }
    }

    const int ordlog2_;  // exponent of 2 in mod - 1
    const Tp zeta_;      // nonresidue ^ ((mod - 1) >> ordlog2_)
    const Tp invzeta_;   // inverse of zeta_
    mutable std::vector<Tp> root_;
    mutable std::vector<Tp> invroot_;

    NttInfo()
        : ordlog2_(__builtin_ctz(Tp::mod() - 1)),
          zeta_(least_quadratic_nonresidue().pow((Tp::mod() - 1) >> ordlog2_)),
          invzeta_(zeta_.inv()), root_{Tp::from_raw(1)}, invroot_{Tp::from_raw(1)} {}

public:
    static const NttInfo &get() {
        static NttInfo info;
        return info;
    }

    Tp zeta() const { return zeta_; }
    Tp inv_zeta() const { return invzeta_; }

    // Ensures forward twiddles are available for indices [0, n) and returns the table.
    // `n` must be zero or a power of two.
    const std::vector<Tp> &root(int n) const {
        assert((n & (n - 1)) == 0);
        extend(root_, zeta_, ordlog2_, n);
        return root_;
    }

    // Ensures inverse twiddles are available for indices [0, n) and returns the table.
    // `n` must be zero or a power of two.
    const std::vector<Tp> &inv_root(int n) const {
        assert((n & (n - 1)) == 0);
        extend(invroot_, invzeta_, ordlog2_, n);
        return invroot_;
    }
};
// Lazily grown tables of factorials and inverse factorials over Tp, shared through
// one function-local static instance. Callers obtain the tables with get(n), which
// guarantees that indices [0, n) are available.
template <typename Tp>
class Binomial {
    std::vector<Tp> factorial_, invfactorial_;

    Binomial() : factorial_{Tp::from_raw(1)}, invfactorial_{Tp::from_raw(1)} {}

    // Makes indices [0, n) available. The table size is doubled until it reaches
    // at least n, so repeated small extensions stay cheap.
    void preprocess(int n) {
        if (const int nn = factorial_.size(); nn < n) {
            int k = nn;
            while (k < n) k *= 2;
            factorial_.resize(k);
            invfactorial_.resize(k);
            for (int i = nn; i != k; ++i) factorial_[i] = factorial_[i - 1] * Tp::from_raw(i);
            // One inversion for the last entry, then walk downwards: 1/i! = (i+1) / (i+1)!.
            invfactorial_.back() = factorial_.back().inv();
            for (int i = k - 2; i >= nn; --i)
                invfactorial_[i] = invfactorial_[i + 1] * Tp::from_raw(i + 1);
        }
    }

public:
    // Returns the shared instance after making sure indices [0, n) are available.
    static const Binomial &get(int n) {
        static Binomial bin;
        bin.preprocess(n);
        return bin;
    }

    // Binomial coefficient C(n, m); zero when m < 0 or n < m.
    // For 0 <= m <= n, n must be below the size requested from get().
    Tp binom(int n, int m) const {
        // m < 0 would index invfactorial_ out of bounds; the coefficient is 0 by convention.
        if (m < 0 || n < m) return Tp();
        return factorial_[n] * invfactorial_[m] * invfactorial_[n - m];
    }

    // 1 / n for n >= 1, computed as (n-1)! / n!.
    Tp inv(int n) const { return factorial_[n - 1] * invfactorial_[n]; }
    Tp factorial(int n) const { return factorial_[n]; }
    Tp inv_factorial(int n) const { return invfactorial_[n]; }
};
// Rounds n up to a power of two (n itself if it already is one); yields 0 for n <= 0.
int ntt_len(int n) {
    // Smear the highest set bit of n - 1 into every lower position, then add one.
    unsigned bits = static_cast<unsigned>(n) - 1u;
    for (int shift = 1; shift <= 16; shift *= 2) bits |= bits >> shift;
    return static_cast<int>(bits + 1u);
}
// In-place forward number-theoretic transform of a vector whose length is a power
// of two. The accompanying article notes that the output is in bit-reversed order,
// which is the order intt() takes as input.
template <typename Tp>
void ntt(std::vector<Tp> &a) {
    const int size = a.size();
    assert((size & (size - 1)) == 0);
    const std::vector<Tp> &twiddle = NttInfo<Tp>::get().root(size / 2);
    for (int span = size; span >= 2; span >>= 1) {
        const int half = span >> 1;
        // Leading block: its twiddle is 1, so the multiplication is skipped.
        for (int p = 0; p < half; ++p) {
            const Tp lo = a[p], hi = a[p + half];
            a[p] = lo + hi;
            a[p + half] = lo - hi;
        }
        // Remaining blocks of this level, one twiddle per block.
        for (int start = span, blk = 1; start < size; start += span, ++blk) {
            for (int p = start; p < start + half; ++p) {
                const Tp lo = a[p], hi = a[p + half] * twiddle[blk];
                a[p] = lo + hi;
                a[p + half] = lo - hi;
            }
        }
    }
}
// In-place inverse of ntt() for a vector whose length is a power of two, including
// the 1/n scaling. The accompanying article notes that the input must be in the
// bit-reversed order that ntt() produces.
template <typename Tp>
void intt(std::vector<Tp> &a) {
    const int n = a.size();
    assert((n & (n - 1)) == 0);
    // Nothing to do for length 0 or 1; returning here also keeps n == 0 away from
    // the Tp::mod() / n division below.
    if (n <= 1) return;
    auto &&root = NttInfo<Tp>::get().inv_root(n / 2);
    for (int i = 2; i < n; i <<= 1) {
        // Leading block: twiddle is 1.
        for (int j = 0, l = i >> 1; j != l; ++j) {
            auto u = a[j], v = a[j + l];
            a[j] = u + v, a[j + l] = u - v;
        }
        for (int j = i, l = i >> 1, m = 1; j != n; j += i, ++m)
            for (int k = j; k != j + l; ++k) {
                auto u = a[k], v = a[k + l];
                a[k] = u + v, a[k + l] = (u - v) * root[m];
            }
    }
    // mod - mod / n equals -(mod - 1) / n, which is n^{-1} when n divides mod - 1.
    const auto iv = Tp::from_raw(Tp::mod() - Tp::mod() / n);
    // Final level, fused with the 1/n scaling.
    for (int j = 0, l = n >> 1; j != l; ++j) {
        auto u = a[j] * iv, v = a[j + l] * iv;
        a[j] = u + v, a[j + l] = u - v;
    }
}
// Returns the product of the polynomials `a` and `b` (coefficient vectors, lowest
// degree first), computed with ntt()/intt(). The result has a.size() + b.size() - 1
// coefficients, or is empty when either operand is empty.
template <typename Tp>
std::vector<Tp> convolution_ntt(std::vector<Tp> a, std::vector<Tp> b) {
    // With an empty operand, n + m - 1 can be <= 0, which would reach intt() with
    // length 0 and resize() with a negative size.
    if (a.empty() || b.empty()) return {};
    const int n = a.size();
    const int m = b.size();
    const int len = ntt_len(n + m - 1);
    a.resize(len);
    b.resize(len);
    ntt(a);
    ntt(b);
    // Pointwise product of the two transforms.
    for (int i = 0; i != len; ++i) a[i] *= b[i];
    intt(a);
    a.resize(n + m - 1);
    return a;
}
// Returns f(g) mod x^n as a vector of exactly n coefficients (empty when n <= 0).
// f and g are coefficient vectors, lowest degree first; g(0) may be non-zero.
template <typename Tp>
std::vector<Tp> composition(const std::vector<Tp> &f, const std::vector<Tp> &g, int n) {
    if (n <= 0) return {};
    if (g.empty()) {
        // g is the zero polynomial, so f(g) is the constant f(0).
        std::vector<Tp> constant(n);
        if (!f.empty()) constant[0] = f[0];
        return constant;
    }
    struct composition_rec {
        // f_[i] is the coefficient of y^(-i)
        composition_rec(const std::vector<Tp> &f, Tp g0) : f_(f), g0_(g0) {}
        // returns [y^0]P/Q mod x^(n+1)
        // The returned flat array has d rows of n + 1 coefficients in x: entries [0, n]
        // hold the coefficient of y^(-d+1), the next n + 1 entries y^(-d+2), and so on
        // up to y^0, so a one-dimensional array is enough.
        std::vector<Tp> run(const std::vector<Tp> &Q, int d, int n) const {
            // Q is read as Q[i * (n + 1) + j] for i in [0, d] and j in [0, n]; on the
            // initial call row 0 is the constant 1 and row 1 is -g.
            if (n == 0) {
                assert(d >= (int)f_.size());  // every coefficient of f must be kept
                // P[i] is the coefficient of y^(-d+1+i)
                std::vector<Tp> P(d), invQ(d);
                for (int i = d - 1, j = 0; j < (int)f_.size() && i >= 0;) P[i--] = f_[j++];
                auto &&bin = Binomial<Tp>::get(d * 2);
                // The coefficients of Q^(-1) come straight from binomial coefficients,
                // so no power-series inversion is needed.
                for (int i = 0; i < d; ++i) invQ[i] = bin.binom(d + i - 1, i) * g0_.pow(i);
                auto PinvQ = convolution_ntt(P, invQ);
                PinvQ.resize(d);
                return PinvQ;
            }
            // deg_y(Q)=d => deg_y(Q(x,y)Q(-x,y))=2d, i.e. 2d+1 coefficients
            // deg_x(Q)=n => deg_x(Q(x,y)Q(-x,y))=2n, i.e. 2n+1 coefficients
            // Substituting y=x^(2n+2) turns this into a univariate convolution from
            // which the needed coefficients can still be separated, and it places the
            // powers of y at even indices.
            // The coefficients of Q(-x,y) need not be computed separately: ntt
            // evaluates at paired points (f(1), f(-1), ...), and because our ntt
            // output is bit-reversed the two values of a pair are adjacent.
            const int len = ntt_len((d * 2 + 1) * (n * 2 + 2) - 1);
            // d is at least 1 and n is at least 1 here, so len is at least 16
            std::vector<Tp> nttQ(len);
            for (int i = 0; i <= d; ++i)
                for (int j = 0; j <= n; ++j) nttQ[i * (n * 2 + 2) + j] = Q[i * (n + 1) + j];
            ntt(nttQ);
            // Extracting only the x^(2k) coefficients needs a half-length intt;
            // see the Bostan--Mori paper.
            std::vector<Tp> VV(len / 2);
            for (int i = 0; i != len; i += 2) VV[i / 2] = nttQ[i] * nttQ[i + 1];
            intt(VV);
            std::vector<Tp> V((d * 2 + 1) * (n / 2 + 1));
            for (int i = 0; i <= d * 2; ++i)
                for (int j = 0; j <= n / 2; ++j) V[i * (n / 2 + 1) + j] = VV[i * (n + 1) + j];
            const auto T = run(V, d * 2, n / 2);
            // T holds the coefficients of y^(-2d+1) through y^0
            std::vector<Tp> nttT(len / 2);
            for (int i = 0; i < d * 2; ++i)
                for (int j = 0; j <= n / 2; ++j) nttT[i * (n + 1) + j] = T[i * (n / 2 + 1) + j];
            // For the same reason a half-length transform is enough here.
            ntt(nttT);
            std::vector<Tp> UU(len);
            for (int i = 0; i != len; i += 2) {
                // Multiply by Q(-x,y): swap the two values of each adjacent pair.
                UU[i] = nttT[i / 2] * nttQ[i + 1];
                UU[i + 1] = nttT[i / 2] * nttQ[i];
            }
            intt(UU);
            std::vector<Tp> U(d * (n + 1));
            // Extract the coefficients of y^(-d+1) through y^0 from the cyclic convolution.
            for (int i = 0; i < d; ++i)
                for (int j = 0; j <= n; ++j) U[i * (n + 1) + j] = UU[(i + d) * (n * 2 + 2) + j];
            return U;
        }

    private:
        const std::vector<Tp> &f_;
        const Tp g0_;
    } a(f, g[0]);
    // f and g are polynomials here. Because g(0) may be non-zero, f cannot be
    // truncated mod x^n: every coefficient is needed, so the recursion runs at
    // x-degree max(n, f.size()) - 1.
    const int deg = std::max(n - 1, (int)f.size() - 1);
    // Q = 1 - y*g(x) as two rows of deg + 1 coefficients: row 0 is 1, row 1 is -g.
    // The row stride must be deg + 1 (not n), because that is how run() indexes Q.
    const int stride = deg + 1;
    std::vector<Tp> Q(stride * 2);
    Q[0] = Tp::from_raw(1);
    for (int i = stride, j = 0; j < (int)g.size() && i < stride * 2;) Q[i++] = -g[j++];
    auto res = a.run(Q, 1, deg);
    res.resize(n);
    return res;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
using mint = ModInt<998244353>;
int n, m;
std::cin >> n >> m;
std::vector<mint> f(n + 1), g(m + 1);
for (int i = 0; i <= n; ++i) std::cin >> f[i];
for (int i = 0; i <= m; ++i) std::cin >> g[i];
auto fg = composition(f, g, n + 1);
for (int i = 0; i <= n; ++i) std::cout << fg[i] << ' ';
return 0;
}
参考文献
- Alin Bostan, Ryuhei Mori. A Simple and Fast Algorithm for Computing the N-th Term of a Linearly Recurrent Sequence.
- noshi91. FPS の合成と逆関数、冪乗の係数列挙 Θ(n (log(n))^2).