题解 P5158 【【模板】多项式快速插值】

· · 题解

咕咕~ 本蒟蒻是看到新模板后现学的,,,于是可能有描述不准确的地方,请见谅。

0. 多项式多点求值

这是做着题的基本,详见P5050,蒟蒻在这里就不详细讲了。

1. 假·解法

### 2. 真·解法 然而还是拉格朗日插值。。。即,$f(x)=\sum_{i=0}^{n-1}\frac {\prod_{j\ne i}(x-x_j)} {\prod_{j\ne i}(x_i-x_j)}y_i
  1. \prod_{j\ne i}(x_i-x_j)

    这是要先计算的。

    L(x)=\prod_{i=0}^{n-1} (x-x_i),R_i(x)=\frac{L(x)}{x-x_i},则有\prod_{j\ne i}(x_i-x_j)=R_i(x_i),我们发现此时分式R_i分子分母都为0,于是有\prod_{j\ne i}(x_i-x_j)=R_i(x_i)=L'(x_i) ^{\tiny{\textcolor{grey}{\text{*根据洛必达法则}}}}

    于是此部分可用多项式多点求值求得。

  2. 整体

    上面那部分求完了,就可以直接求整体了。

    经上一部分的转化,我们现在要求f(x)=\sum_{i=0}^{n-1}\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\prod_{j\ne i}(x-x_j)

    然后我们可以NTT一下,于是超长公式警告

    f(x)=\Large[\normalsize\prod_{i=\lfloor\frac{n}{2}\rfloor+1}^{n-1}(x-x_i)\Large]\{\normalsize\sum_{i=0}^{\lfloor\frac{n}{2}\rfloor}\Large[\normalsize\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\Large][\normalsize\prod_{0\le j\le \lfloor\frac{n}{2}\rfloor,j\ne i}(x-x_j)\Large]\}\normalsize+\Large[\normalsize\prod_{i=0}^{\lfloor\frac{n}{2}\rfloor}(x-x_i)\Large]\normalsize\Large[\normalsize\sum_{i=\lfloor\frac{n}{2}\rfloor+1}^{n-1}\quad \prod_{\lfloor\frac{n}{2}\rfloor<j<n,j\ne i}(x-x_j)\Large]\normalsize

    个人觉得这个公式被强制性加括号后好读了许多。。。

总时间复杂度O(nlogn)

代码如下(封装度可高了QAQ):

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int P = 998244353;

int n, m;
vector<int> x, y, z, ans;
vector<vector<int>> p;

namespace commonTool{
    inline int power(ll x, int y=P-2){
        int ans=1;
        for(; y; y>>=1, x=x*x%P) if(y&1) ans=ans*x%P;
        return ans;
    }

    inline int power2to(int x){
        int n=1;
        while(n<=x) n<<=1;
        return n;
    }

    inline int mod(int x){
        return x<P?x:x-P;
    }
    inline void NTT(vector<int> &f, int g, int n){
        using namespace commonTool;
        f.resize(n);
        for(int i=0, j=0; i<n; ++i){
            if(i>j) swap(f[i], f[j]);
            for(int k=n>>1; (j^=k)<k; k>>=1);
        }
        vector<int> w(n>>1);
        for(int i=1; i<n; i<<=1){
            for(int j=w[0]=1, w0=(g==1?power(3, (P-1)/i/2):power(power(3, (P-1)/i/2))); j<i; ++j) w[j]=(ll)w[j-1]*w0%P;
            for(int j=0; j<n; j+=i<<1){
                for(int k=j; k<j+i; ++k){
                    int t=(ll)f[k+i]*w[k-j]%P;
                    f[k+i]=mod(f[k]-t+P);
                    f[k]=mod(f[k]+t);
                }
            }
        }
        if(g==-1) for(int i=0, I=power(n); i<n; ++i) f[i]=(ll)f[i]*I%P;
    }
}

namespace sptInter{
    inline vector<int> add(const vector<int> &f, const vector<int> &g){
        vector<int> ans=f;
        for(unsigned i=0; i<f.size(); ++i) (ans[i]+=g[i])%=P;
        return ans;
    }

    inline vector<int> multi(const vector<int> &f, const vector<int> &g){
        using namespace commonTool;
        vector<int> F=f, G=g;
        int p=power2to(f.size()+g.size()-2);
        NTT(F, 1, p), NTT(G, 1, p);
        for(int i=0; i<p; ++i) F[i]=(ll)F[i]*G[i]%P;
        NTT(F, -1, p);
        return F.resize(f.size()+g.size()-1), F;
    }

    inline vector<int> inv(const vector<int> &f, int n=-1){
        using namespace commonTool;
        if(n==-1) n=f.size();
        vector<int> ans;
        if(n==1) return ans.push_back(power(f[0])), ans;
        ans=inv(f, (n+1)/2);
        vector<int> tmp(&f[0], &f[0]+n);
        int p=power2to(n*2-2);
        NTT(tmp, 1, p), NTT(ans, 1, p);
        for(int i=0; i<p; ++i) ans[i]=(2-(ll)ans[i]*tmp[i]%P+P)*ans[i]%P;
        NTT(ans, -1, p);
        return ans.resize(n), ans;
    }

    inline void div(const vector<int> &a, const vector<int> &b, vector<int> &d, vector<int> &r){
        if(b.size()>a.size()) return (void)(d.clear(), r=a);
        vector<int> A=a, B=b, iB;
        int n=a.size(), m=b.size();
        reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
        B.resize(n-m+1), iB=inv(B, n-m+1);
        d=multi(A, iB);
        d.resize(n-m+1), reverse(d.begin(), d.end());
        r=multi(b, d);
        for(int i=0; i<m-1; ++i) r[i]=(P+a[i]-r[i])%P;
        r.resize(m-1);
    }

    inline vector<int> der(const vector<int> &a){
        vector<int> ans;
        ans.resize(a.size()-1);
        for(unsigned i=1; i<a.size(); ++i) ans[i-1]=(ll)a[i]*i%P;
        return ans;
    }

    void evalinit(int l, int r, int t, const vector<int> &a){
        if(l==r) return p[t].clear(), p[t].push_back(P-a[l]), p[t].push_back(1);
        int mid=(l+r)/2, k=t<<1;
        evalinit(l, mid, k, a), evalinit(mid+1, r, k|1, a);
        p[t]=multi(p[k], p[k|1]);
    }

    inline void eval(int l, int r, int t, const vector<int> &f, const vector<int> &a){
        if(r-l+1<=512){
            for(int i=l; i<=r; ++i){
                int x=0, j=f.size(), a1=a[i], a2=(ll)a[i]*a[i]%P, a3=(ll)a[i]*a2%P, a4=(ll)a[i]*a3%P, a5=(ll)a[i]*a4%P, a6=(ll)a[i]*a5%P, a7=(ll)a[i]*a6%P, a8=(ll)a[i]*a7%P;
                while(j>=8)
                    x=((ll)x*a8+(ll)f[j-1]*a7+(ll)f[j-2]*a6+(ll)f[j-3]*a5+(ll)f[j-4]*a4+(ll)f[j-5]*a3+(ll)f[j-6]*a2+(ll)f[j-7]*a1+f[j-8])%P, j-=8;
                while(j--) x=((ll)x*a[i]+f[j])%P;
                ans.push_back(x);
            }
            return;
        }
        vector<int> tmp;
        div(f, p[t], tmp, tmp);
        eval(l, (l+r)/2, t<<1, tmp, a), eval((l+r)/2+1, r, t<<1|1, tmp, a);
    }
}

inline vector<int> eval(const vector<int> &f, const vector<int> &a, int flag=-1){
    if(flag==-1) p.resize(a.size()<<2), evalinit(0, a.size()-1, 1, a);
    ans.clear(), eval(0, a.size()-1, 1, f, a);
    return ans;
}

int main() {
    cin>>n>>m, x.resize(n), y.resize(m);
    for(int i=0; i<n; ++i) cin>>x[i];
    for(int i=0; i<m; ++i) cin>>y[i];
    y=eval(x,y);
    for(int i=0; i<m; ++i) cout<<y[i]<<' ';
    return 0;
}