多点求值插值、转置算法以及范德蒙德矩阵逆

· · 题解

多点求值插值、转置算法以及范德蒙德矩阵逆

多点求值

给定 n 项多项式 f(x)a_0,a_1\cdots,a_{m-1},求 f(a_0),f(a_1),f(a_2)\cdots f(a_{m-1})

这就是经典的多点求值问题。我们有容易理解的 \Theta(n\log^2n) 算法:

注意到 f(a_i) \equiv f(x)\pmod{(x-a_i)},所以我们需要求出 f(x)m 个单项式取模的结果。我们分治:将求值点 0\sim m-1 均分为两部分,对于前一部分,当前多项式对 \prod\limits_{i=0}^{m/2}(x-a_i) 取模后递归,对 \prod\limits_{i=m/2+1}^{m-1}(x-a_i) 取模后递归下去即可。取模复杂度 \Theta(n\log n),所以总复杂度是 \Theta(n\log^2n) 的。这个算法容易,但是取模具有极大常数,所以总效率极低。

所以我们有更优、且理论价值更高的转置算法。

转置算法

多点求值算法的本质是相当简单的:求一个向量左乘范德蒙德矩阵

\mathrm V(a_0,\cdots,a_{m-1})=\begin{bmatrix}1 & a_0 & a_0^2 &\cdots& a_0^{n-1} \\ 1 & a_1 & a_1^2 &\cdots& a_1^{n-1} \\ \vdots&\vdots&\vdots &\ddots& \vdots \\ 1 & a_{m-1} & a_{m-1}^2 &\cdots& a_{m-1}^{n-1} \end{bmatrix}

我们观察到,矩阵 \mathrm V 的转置 \mathrm V^{T} 左乘向量是好求的:

\mathrm V^T(a_0,a_1,\cdots,a_{m-1}) = \begin{bmatrix} 1 & 1 &\cdots& 1 \\ a_0 & a_1 &\cdots& a_{m-1} \\ \vdots &\vdots &\ddots& \vdots \\ a_0^{n-1} & a_1^{n-1} &\cdots& a_{m-1}^{n-1} \end{bmatrix} ,\quad\quad g_i = \sum_{j=0}^{m-1} f_ja_j^i

这里若 \mathrm V 不是方阵,无法右乘相同的向量。所以认为 n=m,不失一般性。

G(x)g_i 的生成函数,有

G(x) = \sum_{n\ge 0}\sum_{i=0}^{m-1}f_i(a_ix)^n = \sum_{i=0}^{m-1}\frac {f_i}{1-a_ix}

较为清晰的形式下,可以采用分治求 G(x) 的前 n 项。维护分母即 Q_{l,r}(x) = \prod\limits_{i=l}^{r}(1-a_ix) 和分子 P_{l,r}(x)

对于分治区间 [l,r] 及中点 m

\begin{aligned} Q_{l,r}(x) &= Q_{l,m}(x)\times Q_{m + 1,r}(x) \\ P_{l,r}(x) &= P_{l,m}(x)\times Q_{m + 1,r}(x) + P_{m+1,r}(x)\times Q_{l,m}(x) \end{aligned}

注意到乘法和加法都不超过区间长度项,所以算法复杂度是 \Theta(n\log^2n) 的。

这个矩阵转置后的算法和我们需要的算法有什么联系呢?

考虑 A,B 矩阵,(AB)^T = B^TA^T,以及 (A_1A_2\cdots A_k)^T = A_k^TA_{k-1}^T\cdots A_0^T

考虑到多点求值整个算法是关于初始向量的线性变换,以及其中的每个步骤都是线性的;具体来说,观察分治合并过程,Q(x) 的求解和初始向量 f_i 无关,可以视作常量。而 P(x) 的求解是有关 f_i 的关于 x 的多项式与常多项式相乘再相加,为什么它是线性的呢?可以看出,[x^k]P(x) 始终具有 \sum\limits_{i=0}^{m-1}f_ic_{k,i} 的形式,c_{k,i} 是常数。这和矩阵作用下向量得到的项是相符的。

也就是说,P(x) 是一个矩阵(变换),行按照 x 的幂标号,列按照 f_i 标号。而 Q(x) 是一个(列)向量。

我们将求解转置变换 \mathrm V^T 问题的每一个较为基本的步骤逆序进行,并修改为转置过程,就可以得到原变换 \mathrm V 问题的算法,不妨称为转置算法。

我们可以相信,转置算法和原算法有相同的复杂度;因为我们进行相同多次的线性运算。

这就是转置原理,或称特勒根原理\text{Tellegen's Principle})。

现在我们尝试描述上述算法的转置算法。

首先需要说明的是,我们将如何分解矩阵 \mathrm V^T 。如果分解为初等矩阵的话,所有转置是显然的,但是出于封装和存在中间变量的关系,不妨分析将哪些步骤作为基本步骤逆序并转置。原算法的流程如下:

\begin{aligned} &1\ \ \text{Solve} (l,r): \\ &2\ \quad\quad \text{if } l = r: \text{return }f_l \\ &3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\ &4\ \quad\quad F,x,y \leftarrow 0 \quad\quad\quad\quad \text{// initial value can be given before all, doesn't count}\\ &5\ \quad\quad L \leftarrow \mathrm{Solve}(l,m) \\ &6\ \quad\quad R \leftarrow \mathrm{Solve}(m+1,r) \\ &7\ \quad\quad x \leftarrow x + L\times Q_{m+1,r}(x) \\ &8\ \quad\quad y \leftarrow y + R\times Q_{l,m}(x) \\ &9\ \quad\quad F \leftarrow F+1\times x \\ 1&0\ \quad\quad F \leftarrow F+1\times y \\ 1&1\ \quad\quad \mathrm {return\ } F \\ 1&2\ \ g \leftarrow g+\mathrm{Solve}(0,m-1)\times Q^{-1}_{0,m-1}(x) \end{aligned}

注意到我们将赋值和加法运算修改为了初等变换的形式 x \leftarrow x + c\times y(请注意 x,y,L,R 都是关于 f_i 的向量;函数返回值占用的变量和 L,R 认为是一致的,这里不存在赋值)。它的转置是 y\leftarrow y + c\times x。 这是十分重要的。

程序中有一些实际上顺序无关的过程,比如 5,6 之间,7,89,10 之间。

这是一个后序遍历的递归算法。所以我们改成先序遍历。观察到,每一层递归的 F 值事先由上一层给出了,所以改成不返回的过程,多带一个参数。尝试写出:

\begin{aligned} &1\ \ \text{Solve} (l,r,F): \\ &2\ \quad\quad \text{if } l = r: f_l\leftarrow F_0;\text{ return } \\ &3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\ &4\ \quad\quad x \leftarrow x+1\times^T F \\ &5\ \quad\quad y \leftarrow y+1\times^T F \\ &6\ \quad\quad L \leftarrow L + x\times^T Q_{m+1,r}(x) \\ &7\ \quad\quad R \leftarrow R + y\times^T Q_{l,m}(x) \\ &8\ \quad\quad \mathrm{Solve}(l,m,L) \\ &9\ \quad\quad \mathrm{Solve}(m+1,r,R) \\ 1&0\ \quad\quad \mathrm{return} \\ 1&1\ \ \mathrm{Solve}(0,m-1,g\times^T Q^{-1}_{0,m-1}(x)) \end{aligned}

形式意外地更加简单,再将 x,y 省略掉可得

\begin{aligned} &1\ \ \text{Solve} (l,r,F): \\ &2\ \quad\quad \text{if } l = r: f_l\leftarrow F_0;\text{ return } \\ &3\ \quad\quad m \leftarrow \lfloor(l+r)/2\rfloor \\ &6\ \quad\quad L \leftarrow L + F\times^T Q_{m+1,r}(x) \\ &7\ \quad\quad R \leftarrow R + F\times^T Q_{l,m}(x) \\ &8\ \quad\quad \mathrm{Solve}(l,m,L) \\ &9\ \quad\quad \mathrm{Solve}(m+1,r,R) \\ 1&0\ \quad\quad \mathrm{return} \\ 1&1\ \ \mathrm{Solve}(0,m-1,g\times^T Q^{-1}_{0,m-1}(x)) \end{aligned}

这里 \times^T 是多项式乘法,即各项卷积的转置;实际上只要把每一次「两项相乘累加」的过程置换即可,它们的顺序是无关紧要的。

c_k = \sum\limits_{i=0}^ka_{k-i}b_i,其中 a_i 是初始向量的线性组合而 b_i 是常量,那么对于过程

c_k \leftarrow c_k + a_{k-i}\times b_i

置换得到

a_{k-i} \leftarrow a_{k-i} + c_k\times b_i

这是差项卷积。但是注意到原乘法超过项数的不会有贡献,所以需要截取差卷积的 n-m+1 项(除了递归前的一次乘法),同时保证了复杂度仍是 \Theta(n\log^2n)

另外,除了分析累加以外,还有一种 DFT 转置的解法。

考虑另一种乘法的描述:DFT,对应点相乘,IDFT。

DFT 的变换矩阵 (\omega_n^i)^j = (\omega_n^j)^i = (\omega_n)^{ij},所以它是对称的,转置等于本身。 IDFT 又可以分解为一次 DFT,一次数乘和一次逆序。逆序变换的矩阵就是副对角线,也是对称的。

可以看出,转置后就是:逆序,数乘,DFT,对应点相乘,DFT。这里的若干项都是严格意义上的「变换」了,没有必要展开。

自此,我们已经完全解决了多点求值问题。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>

const int N = 64000 * 4;
const int p = 998244353;
typedef std::vector <int> Poly;

char buf[1 << 25] ,*p1 = buf ,*p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf ,1 ,1 << 21 ,stdin) ,p1 == p2) ? EOF : *p1++)
inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

inline int pls(int a,int b) { return a + b >= p ? a + b - p : a + b; }
inline int mus(int a,int b) { return a - b < 0 ? a - b + p : a - b; }
inline int prd(int a,int b) { return 1ll * a * b % p; }
inline int fastpow(int a,int b) {
    int r = 1;
    while(b) {
        if(b & 1) r = 1ll * r * a % p;
        a = 1ll * a * a % p;
        b >>= 1;
    }
    return r;
}

int rev[N];
void NTT(Poly &a) {
    int N = a.size();
    for(int i = 0;i < N;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * N / 2);
    for(int i = 0;i < N;i++) if(i > rev[i]) std::swap(a[i],a[rev[i]]);
    for(int n = 2, m = 1;n <= N;m = n, n <<= 1) {
        int g1 = fastpow(3,(p - 1) / n), t1, t2;
        for(int l = 0;l < N;l += n) {
            int g = 1;
            for(int i = l;i < l + m;i++) {
                t1 = a[i], t2 = prd(a[i + m],g);
                a[i] = pls(t1,t2), a[i + m] = mus(t1,t2);
                g = prd(g,g1);
            }
        }
    }
    return;
}

void INTT(Poly &a) {
    NTT(a), std::reverse(a.begin() + 1,a.end()); 
    int invN = fastpow(a.size(),p - 2);
    for(int i = 0;i < (int)a.size();i++) a[i] = prd(a[i],invN);
}

Poly Mul(Poly a,Poly b) {
    int n = a.size() + b.size() - 1, N = 1;
    while(N < (int)(a.size() + b.size())) N <<= 1;
    a.resize(N), b.resize(N), NTT(a), NTT(b);
    for(int i = 0;i < N;i++) a[i] = prd(a[i],b[i]);
    INTT(a), a.resize(n);
    return a;
}

Poly MulT(Poly a,Poly b) {
    int n = a.size(), m = b.size(); 
    std::reverse(b.begin(),b.end()), b = Mul(a,b);
    for(int i = 0;i < n;i++) a[i] = b[i + m - 1];
    return a;
}

Poly tmp;
Poly Inv(Poly a,int n) {
    if(n == 1) return Poly(1,fastpow(a[0],p - 2));
    Poly b = Inv(a,(n + 1) / 2);
    int N = 1; while(N <= 2 * n) N <<= 1;
    tmp.resize(N), b.resize(N);
    for(int i = 0;i < n;i++) tmp[i] = a[i];
    for(int i = n;i < N;i++) tmp[i] = 0;
    NTT(tmp), NTT(b);
    for(int i = 0;i < N;i++) b[i] = mus(prd(2,b[i]),prd(prd(b[i],b[i]),tmp[i]));
    INTT(b), b.resize(n); return b;
}

int n,m; Poly a,f,v;

#define lc(k) k << 1
#define rc(k) k << 1 | 1
const int NODE = N * 4;
Poly Q[NODE];

void Init(Poly &a,int k,int l,int r) {
    if(l == r) {
        Q[k].resize(2);
        Q[k][0] = 1, Q[k][1] = mus(0,a[l]);
        return;
    }
    int m = (l + r) / 2;
    Init(a,lc(k),l,m), Init(a,rc(k),m + 1,r);
    Q[k] = Mul(Q[lc(k)],Q[rc(k)]);
    return;
}

void Multipoints(int k,int l,int r,Poly F,Poly &g) {
    F.resize(r - l + 1);
    if(l == r) return void(g[l] = F[0]);
    int m = (l + r) / 2;
    Multipoints(lc(k),l,m,MulT(F,Q[rc(k)]),g);
    Multipoints(rc(k),m + 1,r,MulT(F,Q[lc(k)]),g);
    return;
}

void Multipoint(Poly f,Poly a,Poly &v,int n) {
    f.resize(n + 1), a.resize(n);
    Init(a,1,0,n - 1), v.resize(n);
    Multipoints(1,0,n - 1,MulT(f,Inv(Q[1],n + 1)),v);
    return;
}

int main() {
    n = read() + 1, m = read();
    for(int i = 0;i < n;i++) f.push_back(read());
    for(int i = 0;i < m;i++) a.push_back(read());
    Multipoint(f,a,v,std::max(n,m));
    for(int i = 0;i < m;i++) std::printf("%d\n",v[i]);
    return 0;
}

快速插值

有拉格朗日插值公式:

f(x) = \sum_{i=0}^{n-1}y_i\prod_{j\ne i}\frac{x-x_j}{x_i-x_j}

我们要求这个式子。 直观的 \Theta(n^2) 算法是求出 F(x) = \prod\limits_{i=0}^{n-1}(x-x_i)q_i = \prod\limits_{j\ne i}(x_i-x_j),作 n 次多项式除一次式即可。

n 次多项式除单项式?我们不禁把它和刚才的分治取模多点求值算法联系起来。然而这里我们求的是除法商而非余式,所以分治取模是不对的。

考虑多项式 \dfrac{F(x)}{x-x_i},它是连续可导的,且在除 x=x_i 以外的点都和 \prod\limits_{j\ne i}(x-x_j) 值相同。所以 x-x_i 是它的可去间断点。取极限得:

\left.\prod_{j\ne i}(x-x_j)\right|_{x=x_i} = \lim_{x\to x_i}\frac{F(x)}{x-x_i} = \lim_{x\to x_i}F'(x) = F'(x_i)

使用洛必达法则。所以 q_i = F'(x_i)。分治求出 F(x) 后,可以使用多点求值得到 q_i

我们要求 f_{0,n-1}(x) = \sum\limits_{i=0}^{n-1}\dfrac{y_i}{q_i}\prod\limits_{j\ne i}(x-x_j),不妨再记 t_i = \dfrac{y_i}{q_i}。这个形式也可以直接分治计算,复杂度 \Theta(n\log^2n)

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>

const int N = 400001;
const int p = 998244353;
typedef std::vector <int> Poly;

char buf[1 << 25] ,*p1 = buf ,*p2 = buf;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf ,1 ,1 << 21 ,stdin) ,p1 == p2) ? EOF : *p1++)
inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

inline int pls(int a,int b) { return a + b >= p ? a + b - p : a + b; }
inline int mus(int a,int b) { return a - b < 0 ? a - b + p : a - b; }
inline int prd(int a,int b) { return 1ll * a * b % p; }
inline int fastpow(int a,int b) {
    int r = 1;
    while(b) {
        if(b & 1) r = 1ll * r * a % p;
        a = 1ll * a * a % p;
        b >>= 1;
    }
    return r;
}

int rev[N];
void NTT(Poly &a) {
    int N = a.size();
    for(int i = 0;i < N;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * N / 2);
    for(int i = 0;i < N;i++) if(i > rev[i]) std::swap(a[i],a[rev[i]]);
    for(int n = 2, m = 1;n <= N;m = n, n <<= 1) {
        int g1 = fastpow(3,(p - 1) / n), t1, t2;
        for(int l = 0;l < N;l += n) {
            int g = 1;
            for(int i = l;i < l + m;i++) {
                t1 = a[i], t2 = prd(a[i + m],g);
                a[i] = pls(t1,t2), a[i + m] = mus(t1,t2);
                g = prd(g,g1);
            }
        }
    }
    return;
}

void INTT(Poly &a) {
    NTT(a), std::reverse(a.begin() + 1,a.end()); 
    int invN = fastpow(a.size(),p - 2);
    for(int i = 0;i < (int)a.size();i++) a[i] = prd(a[i],invN);
}

Poly Mul(Poly a,Poly b) {
    int n = a.size() + b.size() - 1, N = 1;
    while(N < (int)(a.size() + b.size())) N <<= 1;
    a.resize(N), b.resize(N), NTT(a), NTT(b);
    for(int i = 0;i < N;i++) a[i] = prd(a[i],b[i]);
    INTT(a), a.resize(n);
    return a;
}

Poly MulT(Poly a,Poly b) {
    int n = a.size(), m = b.size(); 
    std::reverse(b.begin(),b.end()), b = Mul(a,b);
    for(int i = 0;i < n;i++) a[i] = b[i + m - 1];
    return a;
}

Poly tmp;
Poly Inv(Poly a,int n) {
    if(n == 1) return Poly(1,fastpow(a[0],p - 2));
    Poly b = Inv(a,(n + 1) / 2);
    int N = 1; while(N <= 2 * n) N <<= 1;
    tmp.resize(N), b.resize(N);
    for(int i = 0;i < n;i++) tmp[i] = a[i];
    for(int i = n;i < N;i++) tmp[i] = 0;
    NTT(tmp), NTT(b);
    for(int i = 0;i < N;i++) b[i] = mus(prd(2,b[i]),prd(prd(b[i],b[i]),tmp[i]));
    INTT(b), b.resize(n); return b;
}

Poly Dervt(Poly a) {
    for(int i = 0;i < (int)a.size() - 1;i++) a[i] = prd(i + 1,a[i + 1]);
    a.pop_back(); return a;
}

#define lc(k) k << 1
#define rc(k) k << 1 | 1
Poly Q[N];

void MultiInit(Poly &a,int k,int l,int r) {
    if(l == r) {
        Q[k].resize(2);
        Q[k][0] = 1, Q[k][1] = mus(0,a[l]);
        return;
    }
    int m = (l + r) / 2;
    MultiInit(a,lc(k),l,m), MultiInit(a,rc(k),m + 1,r);
    Q[k] = Mul(Q[lc(k)],Q[rc(k)]);
    return;
}

void Multipoints(int k,int l,int r,Poly F,Poly &g) {
    F.resize(r - l + 1);
    if(l == r) return void(g[l] = F[0]);
    int m = (l + r) / 2;
    Multipoints(lc(k),l,m,MulT(F,Q[rc(k)]),g);
    Multipoints(rc(k),m + 1,r,MulT(F,Q[lc(k)]),g);
    return;
}

void Multipoint(Poly f,Poly a,Poly &v,int n) {
    f.resize(n + 1), a.resize(n);
    MultiInit(a,1,0,n - 1), v.resize(n);
    Multipoints(1,0,n - 1,MulT(f,Inv(Q[1],n + 1)),v);
    return;
}

int n; Poly x,y,t,f;

Poly F[N];
void InterInit(int k,int l,int r) {
    if(l == r) {
        F[k].resize(2);
        F[k][0] = mus(0,x[l]), F[k][1] = 1;
        return;
    }
    int m = (l + r) / 2;
    InterInit(lc(k),l,m);
    InterInit(rc(k),m + 1,r);
    F[k] = Mul(F[lc(k)],F[rc(k)]);
    return;
}

Poly InterSolve(int k,int l,int r,Poly &t) {
    if(l == r) return Poly(1,t[l]);
    int m = (l + r) / 2;
    Poly L(InterSolve(lc(k),l,m,t));
    Poly R(InterSolve(rc(k),m + 1,r,t));
    R = Mul(R,F[lc(k)]);
    L = Mul(L,F[rc(k)]);
    for(int i = 0;i < (int)R.size();i++) L[i] = pls(L[i],R[i]);
    return L;
}

void Interpolate(Poly x,Poly y,Poly &f,int n) {
    InterInit(1,0,n - 1);
    F[1] = Dervt(F[1]), Multipoint(F[1],x,t,n);
    for(int i = 0;i < n;i++) t[i] = prd(y[i],fastpow(t[i],p - 2));
    f = InterSolve(1,0,n - 1,t);
    return;
}

int main() {
    n = read();
    for(int i = 0;i < n;i++) x.push_back(read()), y.push_back(read());
    Interpolate(x,y,f,n);
    for(int i = 0;i < n;i++) std::printf("%d ",f[i]);
    return 0;
}

范德蒙德矩阵逆

插值是求点值的逆运算,所以插值就是求范德蒙德矩阵的逆乘向量。

这个矩阵没有一个简单的形式,但是我们这里就可以通过拉格朗日公式得到

\mathrm V^{-1}_{ij} = [x^{i}]\prod\limits_{k\ne j}\dfrac{x-x_k}{x_j-x_k}

展开式相当复杂。