题解 P5050 【【模板】多项式多点求值】
Newuser
2018-12-12 22:15:11
问题:给定一个n次多项式f(x),现在请你对于i∈[1,m],求出f(ai)。
我们考虑分治处理,对于i∈[1,m/2] , 我们构造一个g(x) = ∏(x-ai) (1<=i<=n/2),我们发现对于[1,m/2]中所有的x值代入这个函数值都为0,那么我们直接对于[1,m/2]的原函数的f都%上g(x)应该是不会影响(即消除掉若干个g(x)不会影响)。这样就可以同时达到消除大于等于x^(m/2)次项的目的,我们发现消除到最后剩下的那个常数项的值恰好就是我们要求的f(x)(因为除了常数项都被我们消除掉了), 时间复杂度为O(nlog^2n)
我们写成一个类似线段树的结构把插值的多项式存下来。
[欢迎来Newuser小站玩owo](http://www.newuser.top/2018/12/04/%E3%80%90moban%E3%80%91%E5%A4%9A%E9%A1%B9%E5%BC%8F%E7%AE%97%E6%B3%95%E6%A8%A1%E6%9D%BF/)
```cpp
#include<stdio.h>
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<vector>
using namespace std;
const int maxn = 64005;
const int mod = 998244353;
const int g = 3;
int mul(int x,int y) { return 1ll*x*y%mod; }
int add(int x,int y) { x+=y; return x>=mod?x-mod:x; }
int sub(int x,int y) { x-=y; return x<0?x+mod:x; }
int ksm(int a,int b) {
int ans = 1;
for(;b;b>>=1,a=mul(a,a))
if(b&1) ans = mul(ans,a);
return ans;
}
void ntt(int *a,int deg,int dft) {
for(int i=0,j=0;i<deg;i++) {
if(i<j) swap(a[i],a[j]);
for(int k=(deg>>1);(j^=k)<k;k>>=1);
}
for(int st=1;st<deg;st<<=1) {
int dwg = (dft==1) ? ksm(g,(mod-1)/(st<<1)) : ksm(g,mod-1-(mod-1)/(st<<1));
for(int i=0;i<deg;i+=(st<<1)) {
int ng = 1;
for(int j=i;j<i+st;j++) {
int x = a[j]; int y = mul(ng,a[j+st]);
ng = mul(ng,dwg);
a[j] = add(x,y); a[j+st] = add(x,mod-y);
}
}
}
if(dft==1) return;
int invs = ksm(deg,mod-2);
for(int i=0;i<deg;i++) a[i] = mul(invs,a[i]);
}
int F[18][maxn*4];
int n,m;
int ta[maxn*4],tb[maxn*4];
void MULL(int *a,int *b,int *c,int le) {
for(int i=0;i<le;i++) ta[i]=b[i];
for(int j=0;j<le;j++) tb[j]=c[j];
ntt(ta,le,1); ntt(tb,le,1);
for(int i=0;i<le;i++) ta[i]=mul(ta[i],tb[i]);
ntt(ta,le,-1);
for(int i=0;i<le;i++) a[i] = ta[i];
}
void ginv(int deg,int *a,int *b) {
if(deg==1) { b[0]=ksm(a[0],mod-2); return; }
ginv(deg>>1,a,b);
for(int i=0;i<deg;i++) ta[i]=a[i];
for(int i=deg;i<(deg<<1);i++) ta[i]=0;
ntt(ta,deg<<1,1); ntt(b,deg<<1,1);
for(int i=0;i<(deg<<1);i++) b[i]=mul(b[i],sub(2,mul(ta[i],b[i])));
ntt(b,(deg<<1),-1);
for(int i=deg;i<(deg<<1);i++) b[i]=0;
}
int Q[maxn*4],Fr[maxn*4],Gr[maxn*4],Gri[maxn*4],tmp[maxn*4];
void gmod(int *O,int *F,int *G,int n,int m) {//F%G==O (F n G m)
for(int i=0;i<=n;i++) Fr[i] = F[n-i];
for(int i=0;i<=m;i++) Gr[i] = G[m-i];
for(int i=n-m+2;i<=m;i++) Gr[i] = 0;
int le = 1; for(;le<=(n-m);le<<=1);
for(int i=n-m+1;i<=m;i++) Gr[i]=0;
ginv(le,Gr,Gri);
le = 1; for(;le<=(n<<1);le<<=1);
MULL(Q,Gri,Fr,le);
reverse(Q,Q+n-m+1);
for(int i=n-m+1;i<le;i++) Q[i]=0;
MULL(tmp,Q,G,le);
for(int i=0;i<m;i++) O[i] = sub(F[i],tmp[i]);
for(int i=0;i<le;i++) tmp[i] = Fr[i] = Gr[i] = Gri[i] = Q[i] = 0;
}
vector<int>ve[maxn*2];
int tot,rt,ls[maxn*2],rs[maxn*2],X[maxn];
int tpa[maxn*4],tpb[maxn*4],tpc[maxn*4];
void maketree(int &p,int l,int r) {
p = ++tot;
int len = 1; for(;len<r-l+2;len<<=1);
ve[p].resize(len); // len ci
if(l==r) {
ve[p][0] = (mod-(X[l]%mod+mod)%mod); ve[p][1] = 1; return;
}
int mid = (l+r)>>1;
maketree(ls[p],l,mid); maketree(rs[p],mid+1,r);
int ss = ve[ls[p]].size();
for(int i=0;i<ss;i++) tpa[i] = ve[ls[p]][i]; for(int i=ss;i<len;i++) tpa[i]=0;
ss = ve[rs[p]].size();
for(int i=0;i<ss;i++) tpb[i] = ve[rs[p]][i]; for(int i=ss;i<len;i++) tpb[i]=0;
MULL(tpc,tpa,tpb,len);
for(int i=0;i<len;i++) ve[p][i] = tpc[i];
}
int G[maxn*4];
int ANS[maxn];
void DC(int dep,int p,int l,int r,int cs) {
int m = r-l+1;
for(int i=0;i<=m;i++) G[i] = ve[p][i];
if(cs>=m) gmod(F[dep],F[dep-1],G,cs,m);
else {
for(int i=0;i<=m-1;i++) F[dep][i] = F[dep-1][i];
}
for(int i=0;i<=m;i++) G[i] = 0;
if(l==r) {
ANS[l] = F[dep][0];
return;
}
int mid = (l+r)>>1;
DC(dep+1,ls[p],l,mid,m-1);
DC(dep+1,rs[p],mid+1,r,m-1);
}
int main() {
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) {
int x; scanf("%d",&x);
F[0][i] = x;
}
for(int i=1;i<=m;i++) scanf("%d",&X[i]);
maketree(rt,1,m);
for(int i=0;i<=m;i++) G[i] = ve[rt][i];
DC(1,rt,1,m,n);
for(int i=1;i<=m;i++) printf("%d\n",ANS[i]);
}
```