多点求值插值、转置算法以及范德蒙德矩阵逆
Vocalise
2022-01-10 10:27:09
## 多点求值插值、转置算法以及范德蒙德矩阵逆
### 多点求值
给定 $n$ 项多项式 $f(x)$ 及 $a_0,a_1\cdots,a_{m-1}$,求 $f(a_0),f(a_1),f(a_2)\cdots f(a_{m-1})$。
这就是经典的多点求值问题。我们有容易理解的 $\Theta(n\log^2n)$ 算法:
注意到 $f(a_i) \equiv f(x)\pmod{(x-a_i)}$,所以我们需要求出 $f(x)$ 对 $m$ 个单项式取模的结果。我们分治:将求值点 $0\sim m-1$ 均分为两部分,对于前一部分,当前多项式对 $\prod\limits_{i=0}^{m/2}(x-a_i)$ 取模后递归,对 $\prod\limits_{i=m/2+1}^{m-1}(x-a_i)$ 取模后递归下去即可。取模复杂度 $\Theta(n\log n)$,所以总复杂度是 $\Theta(n\log^2n)$ 的。这个算法容易,但是取模具有极大常数,所以总效率极低。
所以我们有更优、且理论价值更高的转置算法。
### 转置算法
多点求值算法的本质是相当简单的:求一个向量左乘范德蒙德矩阵
$$
\mathrm V(a_0,\cdots,a_{m-1})=\begin{bmatrix}1 & a_0 & a_0^2 &\cdots& a_0^{n-1} \\ 1 & a_1 & a_1^2 &\cdots& a_1^{n-1} \\ \vdots&\vdots&\vdots &\ddots& \vdots \\ 1 & a_{m-1} & a_{m-1}^2 &\cdots& a_{m-1}^{n-1} \end{bmatrix}
$$
我们观察到,矩阵 $\mathrm V$ 的转置 $\mathrm V^{T}$ 左乘向量是好求的:
$$
\mathrm V^T(a_0,a_1,\cdots,a_{m-1}) =
\begin{bmatrix}
1 & 1 &\cdots& 1 \\
a_0 & a_1 &\cdots& a_{m-1} \\
\vdots &\vdots &\ddots& \vdots \\
a_0^{n-1} & a_1^{n-1} &\cdots& a_{m-1}^{n-1}
\end{bmatrix}
,\quad\quad g_i = \sum_{j=0}^{m-1} f_ja_j^i
$$
这里若 $\mathrm V$ 不是方阵,无法右乘相同的向量。所以认为 $n=m$,不失一般性。
记 $G(x)$ 为 $g_i$ 的生成函数,有
$$
G(x) = \sum_{n\ge 0}\sum_{i=0}^{m-1}f_i(a_ix)^n = \sum_{i=0}^{m-1}\frac {f_i}{1-a_ix}
$$
较为清晰的形式下,可以采用分治求 $G(x)$ 的前 $n$ 项。维护分母即 $Q_{l,r}(x) = \prod\limits_{i=l}^{r}(1-a_ix)$ 和分子 $P_{l,r}(x)$。
对于分治区间 $[l,r]$ 及中点 $m$,
$$
\begin{aligned}
Q_{l,r}(x) &= Q_{l,m}(x)\times Q_{m + 1,r}(x) \\
P_{l,r}(x) &= P_{l,m}(x)\times Q_{m + 1,r}(x) + P_{m+1,r}(x)\times Q_{l,m}(x)
\end{aligned}
$$
注意到乘法和加法都不超过区间长度项,所以算法复杂度是 $\Theta(n\log^2n)$ 的。
这个矩阵转置后的算法和我们需要的算法有什么联系呢?
考虑 $A,B$ 矩阵,$(AB)^T = B^TA^T$,以及 $(A_1A_2\cdots A_k)^T = A_k^TA_{k-1}^T\cdots A_0^T$。
考虑到多点求值整个算法是**关于初始向量**的线性变换,以及其中的每个步骤都是线性的;具体来说,观察分治合并过程,$Q(x)$ 的求解和初始向量 $f_i$ 无关,可以视作常量。而 $P(x)$ 的求解是有关 $f_i$ 的关于 $x$ 的多项式与常多项式相乘再相加,为什么它是线性的呢?可以看出,$[x^k]P(x)$ 始终具有 $\sum\limits_{i=0}^{m-1}f_ic_{k,i}$ 的形式,$c_{k,i}$ 是常数。这和矩阵作用下向量得到的项是相符的。
**也就是说,$P(x)$ 是一个矩阵(变换),行按照 $x$ 的幂标号,列按照 $f_i$ 标号。而 $Q(x)$ 是一个(列)向量。**
我们将求解转置变换 $\mathrm V^T$ 问题的每一个较为基本的步骤逆序进行,并修改为转置过程,就可以得到原变换 $\mathrm V$ 问题的算法,不妨称为转置算法。
我们可以相信,转置算法和原算法有相同的复杂度;因为我们进行相同多次的线性运算。
这就是**转置原理**,或称**特勒根原理**($\text{Tellegen's Principle}$)。
---
现在我们尝试描述上述算法的转置算法。
首先需要说明的是,我们将如何分解矩阵 $\mathrm V^T$ 。如果分解为初等矩阵的话,所有转置是显然的,但是出于封装和存在中间变量的关系,不妨分析将哪些步骤作为基本步骤逆序并转置。原算法的流程如下:
$$
\begin{aligned}
&1\ \ \text{Solve} (l,r): \\
&2\ \quad\quad \text{if } l = r: \text{return }f_l \\
&3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\
&4\ \quad\quad F,x,y \leftarrow 0 \quad\quad\quad\quad \text{// initial value can be given before all, doesn't count}\\
&5\ \quad\quad L \leftarrow \mathrm{Solve}(l,m) \\
&6\ \quad\quad R \leftarrow \mathrm{Solve}(m+1,r) \\
&7\ \quad\quad x \leftarrow x + L\times Q_{m+1,r}(x) \\
&8\ \quad\quad y \leftarrow y + R\times Q_{l,m}(x) \\
&9\ \quad\quad F \leftarrow F+1\times x \\
1&0\ \quad\quad F \leftarrow F+1\times y \\
1&1\ \quad\quad \mathrm {return\ } F \\
1&2\ \ g \leftarrow g+\mathrm{Solve}(0,m-1)\times Q_{0,m-1}(x)
\end{aligned}
$$
注意到我们将赋值和加法运算修改为了初等变换的形式 $x \leftarrow x + c\times y$(请注意 $x,y,L,R$ 都是关于 $f_i$ 的向量;函数返回值占用的变量和 $L,R$ 认为是一致的,这里不存在赋值)。它的转置是 $y\leftarrow y + c\times x$。 这是十分重要的。
程序中有一些实际上顺序无关的过程,比如 $5,6$ 之间,$7,8$ 及 $9,10$ 之间。
这是一个后序遍历的递归算法。所以我们改成先序遍历。观察到,**每一层递归的 $F$ 值事先由上一层给出了**,所以改成不返回的过程,多带一个参数。尝试写出:
$$
\begin{aligned}
&1\ \ \text{Solve} (l,r,F): \\
&2\ \quad\quad \text{if } l = r: f_l\leftarrow F_0;\text{ return } \\
&3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\
&4\ \quad\quad x \leftarrow x+1\times^T F \\
&5\ \quad\quad y \leftarrow y+1\times^T F \\
&6\ \quad\quad L \leftarrow L + x\times^T Q_{m+1,r}(x) \\
&7\ \quad\quad R \leftarrow R + y\times^T Q_{l,m}(x) \\
&8\ \quad\quad \mathrm{Solve}(l,m,L) \\
&9\ \quad\quad \mathrm{Solve}(m+1,r,R) \\
1&0\ \quad\quad \mathrm{return} \\
1&1\ \ \mathrm{Solve}(0,m-1,g\times^T Q_{0,m-1}(x))
\end{aligned}
$$
形式意外地更加简单,再将 $x,y$ 省略掉可得
$$
\begin{aligned}
&1\ \ \text{Solve} (l,r,F): \\
&2\ \quad\quad \text{if } l = r: f_l\leftarrow F_0;\text{ return } \\
&3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\
&6\ \quad\quad L \leftarrow L + F\times^T Q_{m+1,r}(x) \\
&7\ \quad\quad R \leftarrow R + F\times^T Q_{l,m}(x) \\
&8\ \quad\quad \mathrm{Solve}(l,m,L) \\
&9\ \quad\quad \mathrm{Solve}(m+1,r,R) \\
1&0\ \quad\quad \mathrm{return} \\
1&1\ \ \mathrm{Solve}(0,m-1,g\times^T Q_{0,m-1}(x))
\end{aligned}
$$
这里 $\times^T$ 是多项式乘法,即各项卷积的转置;实际上只要把每一次「两项相乘累加」的过程置换即可,它们的顺序是无关紧要的。
设 $c_k = \sum\limits_{i=0}^ka_{k-i}b_i$,其中 $a_i$ 是初始向量的线性组合而 $b_i$ 是常量,那么对于过程
$$
c_k \leftarrow c_k + a_{k-i}\times b_i
$$
置换得到
$$
a_{k-i} \leftarrow a_{k-i} + c_k\times b_i
$$
这是差项卷积。但是注意到原乘法超过项数的不会有贡献,所以需要截取差卷积的 $n-m+1$ 项(除了递归前的一次乘法),同时保证了复杂度仍是 $\Theta(n\log^2n)$。
另外,除了分析累加以外,还有一种 DFT 转置的解法。
考虑另一种乘法的描述:DFT,对应点相乘,IDFT。
DFT 的变换矩阵 $(\omega_n^i)^j = (\omega_n^j)^i = (\omega_n)^{ij}$,所以它是对称的,转置等于本身。 IDFT 又可以分解为一次 DFT,一次数乘和一次逆序。逆序变换的矩阵就是副对角线,也是对称的。
可以看出,转置后就是:逆序,数乘,DFT,对应点相乘,DFT。这里的若干项都是严格意义上的「变换」了,没有必要展开。
自此,我们已经完全解决了多点求值问题。
```c++
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>
const int N = 64000 * 4;
const int p = 998244353;
typedef std::vector <int> Poly;
char buf[1 << 25] ,*p1 = buf ,*p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf ,1 ,1 << 21 ,stdin) ,p1 == p2) ? EOF : *p1++)
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
while(ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
inline int pls(int a,int b) { return a + b >= p ? a + b - p : a + b; }
inline int mus(int a,int b) { return a - b < 0 ? a - b + p : a - b; }
inline int prd(int a,int b) { return 1ll * a * b % p; }
inline int fastpow(int a,int b) {
int r = 1;
while(b) {
if(b & 1) r = 1ll * r * a % p;
a = 1ll * a * a % p;
b >>= 1;
}
return r;
}
int rev[N];
void NTT(Poly &a) {
int N = a.size();
for(int i = 0;i < N;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * N / 2);
for(int i = 0;i < N;i++) if(i > rev[i]) std::swap(a[i],a[rev[i]]);
for(int n = 2, m = 1;n <= N;m = n, n <<= 1) {
int g1 = fastpow(3,(p - 1) / n), t1, t2;
for(int l = 0;l < N;l += n) {
int g = 1;
for(int i = l;i < l + m;i++) {
t1 = a[i], t2 = prd(a[i + m],g);
a[i] = pls(t1,t2), a[i + m] = mus(t1,t2);
g = prd(g,g1);
}
}
}
return;
}
void INTT(Poly &a) {
NTT(a), std::reverse(a.begin() + 1,a.end());
int invN = fastpow(a.size(),p - 2);
for(int i = 0;i < (int)a.size();i++) a[i] = prd(a[i],invN);
}
Poly Mul(Poly a,Poly b) {
int n = a.size() + b.size() - 1, N = 1;
while(N < (int)(a.size() + b.size())) N <<= 1;
a.resize(N), b.resize(N), NTT(a), NTT(b);
for(int i = 0;i < N;i++) a[i] = prd(a[i],b[i]);
INTT(a), a.resize(n);
return a;
}
Poly MulT(Poly a,Poly b) {
int n = a.size(), m = b.size();
std::reverse(b.begin(),b.end()), b = Mul(a,b);
for(int i = 0;i < n;i++) a[i] = b[i + m - 1];
return a;
}
Poly tmp;
Poly Inv(Poly a,int n) {
if(n == 1) return Poly(1,fastpow(a[0],p - 2));
Poly b = Inv(a,(n + 1) / 2);
int N = 1; while(N <= 2 * n) N <<= 1;
tmp.resize(N), b.resize(N);
for(int i = 0;i < n;i++) tmp[i] = a[i];
for(int i = n;i < N;i++) tmp[i] = 0;
NTT(tmp), NTT(b);
for(int i = 0;i < N;i++) b[i] = mus(prd(2,b[i]),prd(prd(b[i],b[i]),tmp[i]));
INTT(b), b.resize(n); return b;
}
int n,m; Poly a,f,v;
#define lc(k) k << 1
#define rc(k) k << 1 | 1
const int NODE = N * 4;
Poly Q[NODE];
void Init(Poly &a,int k,int l,int r) {
if(l == r) {
Q[k].resize(2);
Q[k][0] = 1, Q[k][1] = mus(0,a[l]);
return;
}
int m = (l + r) / 2;
Init(a,lc(k),l,m), Init(a,rc(k),m + 1,r);
Q[k] = Mul(Q[lc(k)],Q[rc(k)]);
return;
}
void Multipoints(int k,int l,int r,Poly F,Poly &g) {
F.resize(r - l + 1);
if(l == r) return void(g[l] = F[0]);
int m = (l + r) / 2;
Multipoints(lc(k),l,m,MulT(F,Q[rc(k)]),g);
Multipoints(rc(k),m + 1,r,MulT(F,Q[lc(k)]),g);
return;
}
void Multipoint(Poly f,Poly a,Poly &v,int n) {
f.resize(n + 1), a.resize(n);
Init(a,1,0,n - 1), v.resize(n);
Multipoints(1,0,n - 1,MulT(f,Inv(Q[1],n + 1)),v);
return;
}
int main() {
n = read() + 1, m = read();
for(int i = 0;i < n;i++) f.push_back(read());
for(int i = 0;i < m;i++) a.push_back(read());
Multipoint(f,a,v,std::max(n,m));
for(int i = 0;i < m;i++) std::printf("%d\n",v[i]);
return 0;
}
```
## 快速插值
有拉格朗日插值公式:
$$
f(x) = \sum_{i=0}^{n-1}y_i\prod_{j\ne i}\frac{x-x_j}{x_i-x_j}
$$
我们要求这个式子。 直观的 $\Theta(n^2)$ 算法是求出 $F(x) = \prod\limits_{i=0}^{n-1}(x-x_i)$ 和 $q_i = \prod\limits_{j\ne i}(x_i-x_j)$,作 $n$ 次多项式除一次式即可。
求 $n$ 次多项式除单项式?我们不禁把它和刚才的分治取模多点求值算法联系起来。然而这里我们求的是除法商而非余式,所以分治取模是不对的。
考虑多项式 $\dfrac{F(x)}{x-x_i}$,它是连续可导的,且在除 $x=x_i$ 以外的点都和 $\prod\limits_{j\ne i}(x-x_j)$ 值相同。所以 $x-x_i$ 是它的可去间断点。取极限得:
$$
\left.\prod_{j\ne i}(x-x_j)\right|_{x=x_i} = \lim_{x\to x_i}\frac{F(x)}{x-x_i} = \lim_{x\to x_i}F'(x) = F'(x_i)
$$
使用洛必达法则。所以 $q_i = F'(x_i)$。分治求出 $F(x)$ 后,可以使用多点求值得到 $q_i$。
我们要求 $f_{0,n-1}(x) = \sum\limits_{i=0}^{n-1}\dfrac{y_i}{q_i}\prod\limits_{j\ne i}(x-x_j)$,不妨再记 $t_i = \dfrac{y_i}{q_i}$。这个形式也可以直接分治计算,复杂度 $\Theta(n\log^2n)$。
```c++
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>
const int N = 400001;
const int p = 998244353;
typedef std::vector <int> Poly;
char buf[1 << 25] ,*p1 = buf ,*p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf ,1 ,1 << 21 ,stdin) ,p1 == p2) ? EOF : *p1++)
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
while(ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
inline int pls(int a,int b) { return a + b >= p ? a + b - p : a + b; }
inline int mus(int a,int b) { return a - b < 0 ? a - b + p : a - b; }
inline int prd(int a,int b) { return 1ll * a * b % p; }
inline int fastpow(int a,int b) {
int r = 1;
while(b) {
if(b & 1) r = 1ll * r * a % p;
a = 1ll * a * a % p;
b >>= 1;
}
return r;
}
int rev[N];
void NTT(Poly &a) {
int N = a.size();
for(int i = 0;i < N;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * N / 2);
for(int i = 0;i < N;i++) if(i > rev[i]) std::swap(a[i],a[rev[i]]);
for(int n = 2, m = 1;n <= N;m = n, n <<= 1) {
int g1 = fastpow(3,(p - 1) / n), t1, t2;
for(int l = 0;l < N;l += n) {
int g = 1;
for(int i = l;i < l + m;i++) {
t1 = a[i], t2 = prd(a[i + m],g);
a[i] = pls(t1,t2), a[i + m] = mus(t1,t2);
g = prd(g,g1);
}
}
}
return;
}
void INTT(Poly &a) {
NTT(a), std::reverse(a.begin() + 1,a.end());
int invN = fastpow(a.size(),p - 2);
for(int i = 0;i < (int)a.size();i++) a[i] = prd(a[i],invN);
}
Poly Mul(Poly a,Poly b) {
int n = a.size() + b.size() - 1, N = 1;
while(N < (int)(a.size() + b.size())) N <<= 1;
a.resize(N), b.resize(N), NTT(a), NTT(b);
for(int i = 0;i < N;i++) a[i] = prd(a[i],b[i]);
INTT(a), a.resize(n);
return a;
}
Poly MulT(Poly a,Poly b) {
int n = a.size(), m = b.size();
std::reverse(b.begin(),b.end()), b = Mul(a,b);
for(int i = 0;i < n;i++) a[i] = b[i + m - 1];
return a;
}
Poly tmp;
Poly Inv(Poly a,int n) {
if(n == 1) return Poly(1,fastpow(a[0],p - 2));
Poly b = Inv(a,(n + 1) / 2);
int N = 1; while(N <= 2 * n) N <<= 1;
tmp.resize(N), b.resize(N);
for(int i = 0;i < n;i++) tmp[i] = a[i];
for(int i = n;i < N;i++) tmp[i] = 0;
NTT(tmp), NTT(b);
for(int i = 0;i < N;i++) b[i] = mus(prd(2,b[i]),prd(prd(b[i],b[i]),tmp[i]));
INTT(b), b.resize(n); return b;
}
Poly Dervt(Poly a) {
for(int i = 0;i < (int)a.size() - 1;i++) a[i] = prd(i + 1,a[i + 1]);
a.pop_back(); return a;
}
#define lc(k) k << 1
#define rc(k) k << 1 | 1
Poly Q[N];
void MultiInit(Poly &a,int k,int l,int r) {
if(l == r) {
Q[k].resize(2);
Q[k][0] = 1, Q[k][1] = mus(0,a[l]);
return;
}
int m = (l + r) / 2;
MultiInit(a,lc(k),l,m), MultiInit(a,rc(k),m + 1,r);
Q[k] = Mul(Q[lc(k)],Q[rc(k)]);
return;
}
void Multipoints(int k,int l,int r,Poly F,Poly &g) {
F.resize(r - l + 1);
if(l == r) return void(g[l] = F[0]);
int m = (l + r) / 2;
Multipoints(lc(k),l,m,MulT(F,Q[rc(k)]),g);
Multipoints(rc(k),m + 1,r,MulT(F,Q[lc(k)]),g);
return;
}
void Multipoint(Poly f,Poly a,Poly &v,int n) {
f.resize(n + 1), a.resize(n);
MultiInit(a,1,0,n - 1), v.resize(n);
Multipoints(1,0,n - 1,MulT(f,Inv(Q[1],n + 1)),v);
return;
}
int n; Poly x,y,t,f;
Poly F[N];
void InterInit(int k,int l,int r) {
if(l == r) {
F[k].resize(2);
F[k][0] = mus(0,x[l]), F[k][1] = 1;
return;
}
int m = (l + r) / 2;
InterInit(lc(k),l,m);
InterInit(rc(k),m + 1,r);
F[k] = Mul(F[lc(k)],F[rc(k)]);
return;
}
Poly InterSolve(int k,int l,int r,Poly &t) {
if(l == r) return Poly(1,t[l]);
int m = (l + r) / 2;
Poly L(InterSolve(lc(k),l,m,t));
Poly R(InterSolve(rc(k),m + 1,r,t));
R = Mul(R,F[lc(k)]);
L = Mul(L,F[rc(k)]);
for(int i = 0;i < (int)R.size();i++) L[i] = pls(L[i],R[i]);
return L;
}
void Interpolate(Poly x,Poly y,Poly &f,int n) {
InterInit(1,0,n - 1);
F[1] = Dervt(F[1]), Multipoint(F[1],x,t,n);
for(int i = 0;i < n;i++) t[i] = prd(y[i],fastpow(t[i],p - 2));
f = InterSolve(1,0,n - 1,t);
return;
}
int main() {
n = read();
for(int i = 0;i < n;i++) x.push_back(read()), y.push_back(read());
Interpolate(x,y,f,n);
for(int i = 0;i < n;i++) std::printf("%d ",f[i]);
return 0;
}
```
### 范德蒙德矩阵逆
插值是求点值的逆运算,所以插值就是求范德蒙德矩阵的逆乘向量。
这个矩阵没有一个简单的形式,但是我们这里就可以通过拉格朗日公式得到
$$
\mathrm V^{-1}_{ij} = [x^{i}]\prod\limits_{k\ne j}\dfrac{x-x_k}{x_j-x_k}
$$
展开式相当复杂。