题解 P5158 【【模板】多项式快速插值】
rEdWhitE_uMbrElla · · 题解
咕咕~ 本蒟蒻是看到新模板后现学的,,,于是可能有描述不准确的地方,请见谅。
0. 多项式多点求值
这是做着题的基本,详见P5050,蒟蒻在这里就不详细讲了。
1. 假·解法
-
\prod_{j\ne i}(x_i-x_j) 这是要先计算的。
设
L(x)=\prod_{i=0}^{n-1} (x-x_i),R_i(x)=\frac{L(x)}{x-x_i} ,则有\prod_{j\ne i}(x_i-x_j)=R_i(x_i) ,我们发现此时分式R_i 分子分母都为0,于是有\prod_{j\ne i}(x_i-x_j)=R_i(x_i)=L'(x_i) ^{\tiny{\textcolor{grey}{\text{*根据洛必达法则}}}} 于是此部分可用多项式多点求值求得。
-
整体
上面那部分求完了,就可以直接求整体了。
经上一部分的转化,我们现在要求
f(x)=\sum_{i=0}^{n-1}\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\prod_{j\ne i}(x-x_j) 然后我们可以NTT一下,于是
超长公式警告f(x)=\Large[\normalsize\prod_{i=\lfloor\frac{n}{2}\rfloor+1}^{n-1}(x-x_i)\Large]\{\normalsize\sum_{i=0}^{\lfloor\frac{n}{2}\rfloor}\Large[\normalsize\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\Large][\normalsize\prod_{0\le j\le \lfloor\frac{n}{2}\rfloor,j\ne i}(x-x_j)\Large]\}\normalsize+\Large[\normalsize\prod_{i=0}^{\lfloor\frac{n}{2}\rfloor}(x-x_i)\Large]\normalsize\Large[\normalsize\sum_{i=\lfloor\frac{n}{2}\rfloor+1}^{n-1}\quad \prod_{\lfloor\frac{n}{2}\rfloor<j<n,j\ne i}(x-x_j)\Large]\normalsize 个人觉得这个公式被强制性加括号后好读了许多。。。
总时间复杂度
代码如下(封装度可高了QAQ):
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int P = 998244353;
int n, m;
vector<int> x, y, z, ans;
vector<vector<int>> p;
namespace commonTool{
inline int power(ll x, int y=P-2){
int ans=1;
for(; y; y>>=1, x=x*x%P) if(y&1) ans=ans*x%P;
return ans;
}
inline int power2to(int x){
int n=1;
while(n<=x) n<<=1;
return n;
}
inline int mod(int x){
return x<P?x:x-P;
}
inline void NTT(vector<int> &f, int g, int n){
using namespace commonTool;
f.resize(n);
for(int i=0, j=0; i<n; ++i){
if(i>j) swap(f[i], f[j]);
for(int k=n>>1; (j^=k)<k; k>>=1);
}
vector<int> w(n>>1);
for(int i=1; i<n; i<<=1){
for(int j=w[0]=1, w0=(g==1?power(3, (P-1)/i/2):power(power(3, (P-1)/i/2))); j<i; ++j) w[j]=(ll)w[j-1]*w0%P;
for(int j=0; j<n; j+=i<<1){
for(int k=j; k<j+i; ++k){
int t=(ll)f[k+i]*w[k-j]%P;
f[k+i]=mod(f[k]-t+P);
f[k]=mod(f[k]+t);
}
}
}
if(g==-1) for(int i=0, I=power(n); i<n; ++i) f[i]=(ll)f[i]*I%P;
}
}
namespace sptInter{
inline vector<int> add(const vector<int> &f, const vector<int> &g){
vector<int> ans=f;
for(unsigned i=0; i<f.size(); ++i) (ans[i]+=g[i])%=P;
return ans;
}
inline vector<int> multi(const vector<int> &f, const vector<int> &g){
using namespace commonTool;
vector<int> F=f, G=g;
int p=power2to(f.size()+g.size()-2);
NTT(F, 1, p), NTT(G, 1, p);
for(int i=0; i<p; ++i) F[i]=(ll)F[i]*G[i]%P;
NTT(F, -1, p);
return F.resize(f.size()+g.size()-1), F;
}
inline vector<int> inv(const vector<int> &f, int n=-1){
using namespace commonTool;
if(n==-1) n=f.size();
vector<int> ans;
if(n==1) return ans.push_back(power(f[0])), ans;
ans=inv(f, (n+1)/2);
vector<int> tmp(&f[0], &f[0]+n);
int p=power2to(n*2-2);
NTT(tmp, 1, p), NTT(ans, 1, p);
for(int i=0; i<p; ++i) ans[i]=(2-(ll)ans[i]*tmp[i]%P+P)*ans[i]%P;
NTT(ans, -1, p);
return ans.resize(n), ans;
}
inline void div(const vector<int> &a, const vector<int> &b, vector<int> &d, vector<int> &r){
if(b.size()>a.size()) return (void)(d.clear(), r=a);
vector<int> A=a, B=b, iB;
int n=a.size(), m=b.size();
reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
B.resize(n-m+1), iB=inv(B, n-m+1);
d=multi(A, iB);
d.resize(n-m+1), reverse(d.begin(), d.end());
r=multi(b, d);
for(int i=0; i<m-1; ++i) r[i]=(P+a[i]-r[i])%P;
r.resize(m-1);
}
inline vector<int> der(const vector<int> &a){
vector<int> ans;
ans.resize(a.size()-1);
for(unsigned i=1; i<a.size(); ++i) ans[i-1]=(ll)a[i]*i%P;
return ans;
}
void evalinit(int l, int r, int t, const vector<int> &a){
if(l==r) return p[t].clear(), p[t].push_back(P-a[l]), p[t].push_back(1);
int mid=(l+r)/2, k=t<<1;
evalinit(l, mid, k, a), evalinit(mid+1, r, k|1, a);
p[t]=multi(p[k], p[k|1]);
}
inline void eval(int l, int r, int t, const vector<int> &f, const vector<int> &a){
if(r-l+1<=512){
for(int i=l; i<=r; ++i){
int x=0, j=f.size(), a1=a[i], a2=(ll)a[i]*a[i]%P, a3=(ll)a[i]*a2%P, a4=(ll)a[i]*a3%P, a5=(ll)a[i]*a4%P, a6=(ll)a[i]*a5%P, a7=(ll)a[i]*a6%P, a8=(ll)a[i]*a7%P;
while(j>=8)
x=((ll)x*a8+(ll)f[j-1]*a7+(ll)f[j-2]*a6+(ll)f[j-3]*a5+(ll)f[j-4]*a4+(ll)f[j-5]*a3+(ll)f[j-6]*a2+(ll)f[j-7]*a1+f[j-8])%P, j-=8;
while(j--) x=((ll)x*a[i]+f[j])%P;
ans.push_back(x);
}
return;
}
vector<int> tmp;
div(f, p[t], tmp, tmp);
eval(l, (l+r)/2, t<<1, tmp, a), eval((l+r)/2+1, r, t<<1|1, tmp, a);
}
}
inline vector<int> eval(const vector<int> &f, const vector<int> &a, int flag=-1){
if(flag==-1) p.resize(a.size()<<2), evalinit(0, a.size()-1, 1, a);
ans.clear(), eval(0, a.size()-1, 1, f, a);
return ans;
}
int main() {
cin>>n>>m, x.resize(n), y.resize(m);
for(int i=0; i<n; ++i) cin>>x[i];
for(int i=0; i<m; ++i) cin>>y[i];
y=eval(x,y);
for(int i=0; i<m; ++i) cout<<y[i]<<' ';
return 0;
}