题解 P5050 【【模板】多项式多点求值】

· · 题解

记待求值的多项式为 F(x),需要求出其在 x_0 处的点值 F(x_0)

如果 F(x)=Q(x)G(x)+R(x),则有 F(x_0)=Q(x_0)G(x_0)+R(x_0)。 如果构造 G(x) 使得 G(x_0)=0,那么就有 F(x_0)=R(x_0)

那么构造 G(x)=x-x_0,则 F(x)G(x) 做多项式取模后得到的常数项就是 F(x_0)(因为取模后只有这一项了)。

多项式取模的过程可以分治进行,将 F(x)\prod_{i=1}^{m/2} (x-x_i)\prod_{i=m/2+1}^{m} (x-x_i) 分别取模,再将取模后的结果向两边递归下去。每次 O(mlogm) 的多项式取模都可以使得原问题的规模减半,就可以在 O(nlogn+mlog^2m) 的时间内求出所有的点值(最开始还需要将 F(x)\prod_{i=1}^{m}(x-x_i) 取模)。

这是大部分人的做法,但是鉴于多项式取模的大常数,这个做法的效率并不高。

重新分析一下,R(x) 只有常数项,如果能算出 Q(x) 的常数项,那么就可以计算出 R(x)

考虑一下多项式除法是进行了一个怎么样的过程,假设 F(x),G(x) 的最高次项次数分别为 n,m

F(x)=Q(x)G(x)+R(x) x^nF(\frac{1}{x})=x^{n-m}Q(\frac{1}{x})x^mG(\frac{1}{x})+x^nR(\frac{1}{x}) F_r(x)=Q_r(x)G_r(x)+x^{n-m+1}R_r(x) Q_r(x)=F_r(x)G_r^{-1}(x)\pmod{x^{n-m+1}}

在这里,F_r(x) 表示的是 F_r(x) 的系数翻转后的结果。

G(x)=\prod_{i=1}^{m}(x-x_i),G_0(x)=\prod_{i=1}^{m/2}(x-x_i),G_1(x)=\prod_{i=m/2+1}^{m}(x-x_i)

可以想到先求出 Q_r(x)=F_r(x)G_r^{-1}(x) 然后用利用 G_{0r}^{-1}(x)=G_r^{-1}(x)G_1(x) 来向下递归。

最底层的 Q(x)n 项,也就是说为了保证底层计算结果是正确的,所有的计算都必须至少在 \pmod{x^n} 的意义下进行,这样的复杂度显然是无法接受的。

但是注意到根本不需要计算出 Q(x),只需要计算出 Q(x) 的常数项,或者说 [x^{n-1}]Q_r(x)。换言之,只有某一项可能对最高项产生贡献才需要保留,否则可以直接舍弃。

如果接下来乘上的若干 G(x) 之积的次数界为 k,那么就只需要保留 Q(x) 的最后 k 项,因为前面的若干项不可能对最高项产生贡献。

即,递归到区间 [l,r]Q(x) 只需要保留最高的 r-l+1 项。

这样,每次只需要一次多项式乘法而不是多项式取模就可以使得问题规模减半,效率得到了显著的提升(对我而言,五倍)。

此外,还需要注意:最底层需要第 n 项,但是递归计算时当成只有 m 项来处理,这样会导致 n>m 时答案会错。

这有两个方法解决,第一种就是将 F(x)G(x) 取模。第二种更简单,令 m=\max(m,n),多出来的一些询问不会对答案造成影响。

(一千个人就有一千个多项式板子,代码仅供参考)

#include<bits/stdc++.h>
using namespace std;
#define I inline int
#define V inline void
#define ll long long int
#define isnum(ch) ('0'<=ch&&ch<='9')
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define ROF(i,a,b) for(int i=a;i>=b;i--)
#define gc (_op==_ed&&(_ed=(_op=_buf)+fread(_buf,1,100000,stdin),_op==_ed)?EOF:*_op++)
char _buf[100000],*_op(_buf),*_ed(_buf);
const int N=1<<18|1,mod=998244353;
V check(int&x){x-=mod,x+=x>>31&mod;}
I getint(){
    int _s=0;char _ch=gc;
    while(!isnum(_ch))_ch=gc;
    while(isnum(_ch))_s=_s*10+_ch-48,_ch=gc;
    return _s;
}
I Pow(ll t,int x){
    ll s=1;
    for(;x;x>>=1,t=t*t%mod)if(x&1)s=s*t%mod;
    return s;
}
namespace poly{
    int lmt,w[N],r[N];
    V init(int n){
        int l=-1,wn;
        for(lmt=1;lmt<=n;)lmt<<=1,l++;
        FOR(i,0,lmt-1)r[i]=(r[i>>1]>>1)|((i&1)<<l);
        wn=Pow(3,mod>>++l),w[lmt>>1]=1;
        FOR(i,(lmt>>1)+1,lmt-1)w[i]=1ll*w[i-1]*wn%mod;
        ROF(i,(lmt>>1)-1,1)w[i]=w[i<<1];
    }
    V cl(int*a,int n){memset(a,0,n<<2);}
    I getLen(int n){return 1<<32-__builtin_clz(n);}
    V mul(int*a,int x,int n,int*b){while(n--)*b++=1ll**a++*x%mod;}
    V dot(int*a,int*b,int n,int*c){while(n--)*c++=1ll**a++**b++%mod;}
    V DFT(int*a,int l){
        static unsigned ll tmp[N];
        int u=__builtin_ctz(lmt/l),t;
        FOR(i,0,l-1)tmp[i]=a[r[i]>>u];
        for(int i=1;i^l;i<<=1)for(int j=0,d=i<<1;j^l;j+=d)FOR(k,0,i-1)
            t=tmp[i|j|k]*w[i|k]%mod,tmp[i|j|k]=tmp[j|k]+mod-t,tmp[j|k]+=t;
        FOR(i,0,l-1)a[i]=tmp[i]%mod;
    }
    V IDFT(int*a,int l){reverse(a+1,a+l),DFT(a,l),mul(a,mod-mod/l,l,a);}
    V Inv(const int*a,int n,int*b){
        static int A[N],B[N],tmp[N],d,l;
        tmp[0]=Pow(a[0],mod-2),cl(A,d),cl(B,d);
        for(d=1,l=2;d<n;d<<=1,l<<=1){
            copy(a,a+min(l,n),A),copy(tmp,tmp+d,B);
            DFT(A,l),DFT(B,l),dot(A,B,l,A),IDFT(A,l);
            cl(A,d),DFT(A,l),dot(A,B,l,A),IDFT(A,l);
            copy(A+d,A+l,tmp+d),mul(tmp+d,mod-1,d,tmp+d);
        }
        copy(tmp,tmp+n,b);
    }
    int*f[N],*g[N],bin[N<<5],*np(bin);
    V Mul(int*a,int*b,int n,int m,int*c){
        static int A[N],B[N],l;
        l=getLen(n+m-1),copy(a,a+n,A),copy(b,b+m,B);
        DFT(A,l),DFT(B,l),dot(A,B,l,A),IDFT(A,l);
        copy(A,A+n+m-1,c),cl(A,l),cl(B,l);
    }
    V eva_init(int p,int l,int r,int*a){
        g[p]=np,np+=r-l+2,f[p]=np,np+=r-l+2;
        if(l==r)return g[p][0]=1,check(g[p][1]=mod-a[l]);
        int lc=p<<1,rc=lc|1,mid=l+r>>1,len1=mid-l+2,len2=r-mid+1;
        eva_init(lc,l,mid,a),eva_init(rc,mid+1,r,a);
        Mul(g[lc],g[rc],len1,len2,g[p]);
    }
    V Mult(int*a,int*b,int n,int m,int*c){
        static int A[N],B[N],l;
        l=getLen(n),copy(a,a+n,A),reverse_copy(b,b+m,B);
        DFT(A,l),DFT(B,l),dot(A,B,l,A),IDFT(A,l);
        copy(A+m-1,A+n,c);
        cl(A,l),cl(B,l);
    }
    V eva_work(int p,int l,int r,int*a){
        if(l==r)return void(a[l]=f[p][0]);
        int lc=p<<1,rc=lc|1,mid=l+r>>1,len1=mid-l+2,len2=r-mid+1;
        Mult(f[p],g[rc],r-l+1,len2,f[lc]);
        eva_work(lc,l,mid,a);
        Mult(f[p],g[lc],r-l+1,len1,f[rc]);
        eva_work(rc,mid+1,r,a);
    }
    V eva(int*a,int*b,int n,int m,int*c){
        static int X[N],Y[N],l;
        eva_init(1,1,m,b),Inv(g[1],m+1,X);
        reverse(X,X+m+1),Mul(a,X,n,m+1,Y);
        copy(Y+n,Y+n+m,f[1]),eva_work(1,1,m,c);
        FOR(i,1,m)check(c[i]=1ll*c[i]*b[i]%mod+a[0]);
    }
}
int n,m,a[N],b[N],c[N];
int main(){
    n=getint()+1,m=getint(),poly::init(max(n,m+1)<<1);
    FOR(i,0,n-1)a[i]=getint();
    FOR(i,1,m)b[i]=getint();
    n=max(n,m+1),poly::eva(a,b,n,max(n-1,m),c);
    FOR(i,1,m)cout<<c[i]<<'\n';
    return 0;
}

P.S. 代码里还使用了一个优化,就是从上往下计算 Q 的时候,会有长度为 l 的多项式乘上长度为 \frac{l}{2} 的多项式然后只取后 \frac{l}{2} 项的操作,这部分可以利用 FFT 循环卷积的性质只进行长度为 lDFT