【模板】多项式多点求值 题解

· · 题解

题目大意

给定 n-1 次多项式 f(x)m 个数 a_i ,求所有 f(a_i) 的值。

答案对 998244353 取模。

数据范围 1\leq n,m\leq 6.4\cdot 10^40\leq a_i,[x^i]f(x)\leq 998244353

前置芝士:

多项式除法(但商没用,其实就是取模)。(不用牛顿迭代,甚至非常入门的求导,积分都不要,老少咸宜!)

不会出门左拐即可。。(应该已经没有人读到这里就走了吧)

solution

多项式除法的话推荐俺的文章。(其实随便那篇都行,了解原理了就好了)

先来两个推论:

1.f(a_i)\equiv f(x)\ mod\ (x-a_i)

感觉大佬们的证法不太好理解。。(是我太菜了)

其实还可以这样:

当我们对两个式子同除 (x-a_i) 的时候。(都会写竖式吧

对于当前的被除数最高次项,在对齐了次数与系数后,减去了几个 x ,就会站起来几个 a_i

意思就是 f(x) 除以 (x-a_i) 的余数就可以化归成 f(a_i) 。(当然肯定还能再除下去,主要是证明推论的成立)

那么对于柿子(其实就是多项式除法里的柿子):

f(x)=q(x)\cdot g(x)+r(x)

如果我们构造 g(x)=(x-a_i) 最后 f(x)=r(x) ,注意膜一个一次项的多项式,最后就只会剩常数项(芜湖起飞。。)

2.(f(x)\ mod\ \varphi(x)\cdot g(x))\ mod\ g(x)=f(x)\ mod\ g(x)

把多项式替换成一个常数会好理解得多,模拟一下就好了。c 取模还赶不上对 n\cdot c 取模?

推论说完了(有没有感觉推论二有点迭代的意思?)

预处理出所有区间 [l,r] 所对应的 \prod_{i=l}^{r}{(x-a_i)} ,哇哦,乘积形式哟。

那么就可以考虑分治 FFT 了。

在递归的时候,让多项式直接对当前区间所对应的模数取模,然后继续向下传(根据推论二,这样做也不会有影响),最后到了叶子结点的时候就好了(根据推论一,最后膜完只剩常数项)

时间复杂度 O(n\log_{2}^{2}{n})

人傻常熟大 wtcl 。。

(转置定理什么的压根不知道/kel)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=128e3+10;
typedef ll arr[N<<1];
const int mod=998244353;
const int inv3=332748118;
int n,m,g[N],lim=1,fre,id[N<<1];
arr f[20],fr,gr,gi,q,ans,tp1,tp2,tmp,F1,G1;
int rt,tot,ls[N],rs[N],Ans[N];
vector<int> v[N];
inline ll inc(ll x,ll y){return x+y>=mod?x+y-mod:x+y;}
inline ll dec(ll x,ll y){return x-y<0?x-y+mod:x-y;}
inline ll mul(ll x,ll y){return 1ll*x*y%mod;}
inline ll read()
{
    ll s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
    return s*w;
}
inline int ksm(int a,int b)
{
    int tmp=1;
    while(b)
    {
        if(b&1)tmp=(1ll*tmp*a)%mod;
        b>>=1,a=(1ll*a*a)%mod;
    }
    return tmp;
}
inline void Never_Tell_TLE(ll* NTT,int sign)
{
    for(int i=0;i<=lim;++i)if(i<id[i])
    {
        ll NTt=NTT[i];
        NTT[i]=NTT[id[i]];
        NTT[id[i]]=NTt;
    }
    for(int mid=1;mid<lim;mid<<=1)
    {
        int Unit_root;
        if(sign==1)Unit_root=ksm(3,(mod-1)/(mid<<1));
        else Unit_root=ksm(inv3,(mod-1)/(mid<<1));
        for(int R=mid<<1,r=0;r<lim;r+=R)
        {
            int pw=1;
            for(int l=0;l<mid;++l,pw=(1ll*pw*Unit_root)%mod)
            {
                int butt=NTT[l+r],rfly=(1ll*pw*NTT[l+r+mid])%mod;
                NTT[l+r]=inc(butt,rfly);
                NTT[l+r+mid]=dec(butt,rfly);
            }
        }
    }
    if(sign==-1)
    {
        int inv_lim=ksm(lim,mod-2);
        for(int i=0;i<=lim;++i)NTT[i]=(1ll*NTT[i]*inv_lim)%mod;
    }
}
/*
此Inv非彼Inv,一定要注意
因为多项式是0~n/m而非0~(n-1)/(m-1)
调用的时候也要多加注意 
*/
inline void Inv(ll* F,ll* G,int nm)
{
    if(nm==1)
    {
        G[0]=ksm(F[0],mod-2);
        return;
    }
    Inv(F,G,nm>>1);
    lim=1,fre=0;
    for(;lim<=(nm<<1);lim<<=1)++fre;
    for(int i=0;i<=lim;++i)id[i]=(id[i>>1]>>1)|((i&1)<<(fre-1));
    for(int i=0;i<=nm;++i)ans[i]=F[i];
    for(int i=nm+1;i<=lim;++i)ans[i]=0;
    Never_Tell_TLE(ans,1),Never_Tell_TLE(G,1);
    for(int i=0;i<=lim;++i)G[i]=mul(inc(2ll,-mul(ans[i],G[i])+mod),G[i]);
    Never_Tell_TLE(G,-1);
    for(int i=nm;i<=lim;++i)G[i]=0;
}
inline void Mul(ll* F,ll* G,int nm)
{
    if(nm)
    {
        lim=1,fre=0;
        for(;lim<=(nm<<1);lim<<=1)++fre;
    }
    for(int i=0;i<=lim;++i)id[i]=(id[i>>1]>>1)|((i&1)<<(fre-1));
    for(int i=0;i<lim;++i)F1[i]=F[i],G1[i]=G[i];
    Never_Tell_TLE(F1,1),Never_Tell_TLE(G1,1);
    for(int i=0;i<lim;++i)F1[i]=mul(F1[i],G1[i]);
    Never_Tell_TLE(F1,-1);
    for(int i=0;i<lim;++i)F[i]=F1[i];
}
inline void Mod(ll* F,ll* G,ll* R,int n,int m,ll* FR,ll* GR,ll* GI,ll* Q)
{
    for(int i=0;i<=n;++i)FR[n-i]=F[i];
    for(int i=0;i<=m;++i)GR[m-i]=G[i];
    for(int i=n-m+1;i<=m;++i)GR[i]=0;
    int lm=1,fe=0;
    for(;lm<=(n-m);lm<<=1)++fe;
    Inv(GR,GI,lm);
    Mul(FR,GI,n);
    for(int i=0;i<=n-m;++i)Q[i]=FR[n-m-i];
    for(int i=n-m+1;i<=lim;++i)Q[i]=0;
    Mul(Q,G,n);
    for(int i=0;i<m;++i)R[i]=inc(F[i],-Q[i]+mod);
    for(int i=0;i<=lim;++i)FR[i]=GR[i]=GI[i]=Q[i]=0;
}
inline void tree_form(int &p,int l,int r)
{
    p=++tot;lim=1,fre=0;
    for(;lim<=(r-l+1);lim<<=1)++fre;
    v[p].resize(lim);
    if(l==r)
    {
        v[p][0]=dec(mod,g[l]),v[p][1]=1;
        return;
    }
    int mid=(l+r)>>1,cnt;
    tree_form(ls[p],l,mid),tree_form(rs[p],mid+1,r);
    lim=1,fre=0;
    for(;lim<=(r-l+1);lim<<=1)++fre;
    cnt=v[ls[p]].size();
    for(int i=0;i<cnt;++i)tp1[i]=v[ls[p]][i];
    for(int i=cnt;i<=lim;++i)tp1[i]=0;
    cnt=v[rs[p]].size();
    for(int i=0;i<cnt;++i)tp2[i]=v[rs[p]][i];
    for(int i=cnt;i<=lim;++i)tp2[i]=0;
    Mul(tp1,tp2,0);
    for(int i=0;i<lim;++i)v[p][i]=tp1[i];
}
inline void solve(int p,int l,int r,int dep,int siz)
{
    for(int i=0;i<=r-l+1;++i)tmp[i]=v[p][i];
    if(siz<r-l+1)for(int i=0;i<=r-l;++i)f[dep][i]=f[dep-1][i];
    else Mod(f[dep-1],tmp,f[dep],siz,r-l+1,fr,gr,gi,q);
    for(int i=0;i<=r-l+1;++i)tmp[i]=0;
    if(l==r)
    {
        Ans[l]=f[dep][0];
        return;
    }
    int mid=(l+r)>>1;
    solve(ls[p],l,mid,dep+1,r-l);
    solve(rs[p],mid+1,r,dep+1,r-l);
}
int main()
{
    n=read(),m=read();
    for(int i=0;i<=n;++i)f[0][i]=read();
    for(int i=1;i<=m;++i)g[i]=read();
    tree_form(rt,1,m);
    solve(rt,1,m,1,n);
    for(int i=1;i<=m;++i)printf("%d\n",Ans[i]);
    return 0;
}

只求能帮助到几个人罢。。