题解 P5393 【下降幂多项式转普通多项式】

· · 题解

怎么清一色都先想到多项式多点求值/快速插值呀,这不是直接裸的分治 NTT 吗(不解)。
首先,我们要求的是 $\displaystyle\sum_{i=0}^{n-1} a_i x^{\underline{i}}$,其中 $x^{\underline{i}}=\displaystyle\prod_{j=0}^{i-1}(x-j)$。
设 $g_{l,r}(x)=\displaystyle\prod_{i=l}^{r-1}(x-i)$,$f_{l,r}(x)=\displaystyle\sum_{i=l}^{r-1}a_i\,g_{l,i}(x)$,显然答案为 $f_{0,n}(x)$。
任取 $mid\in(l,r)$,则 $g_{l,r}(x)=g_{l,mid}(x)\,g_{mid,r}(x)$,$f_{l,r}(x)=f_{l,mid}(x)+f_{mid,r}(x)\,g_{l,mid}(x)$。

最后,边界情况为 $g_{i,i+1}(x)=x-i$,$f_{i,i+1}(x)=a_i$,$g_{i,i}(x)=1$,$f_{i,i}(x)=0$。

直接分治即可,复杂度 $\Theta(n\log^2 n)$。

还有一个小优化:$f_{l,r}(x)$、$g_{l,r}(x)$ 至多是 $r-l$ 次的,在 $\pmod{x^{2^k}}$(其中 $2^k=r-l$)意义下会刚好多一项。
考虑到循环卷积的性质,多出来的这一项单独处理,会大大降低常数。

跑得挺快的代码(3.92s/10.57MB/3.55KB )

//this code is made by warzone
//2021.11.9 16:00
#include<stdio.h>
#include<string.h>
using ull = unsigned long long;  // wide type used for products before reduction
using word = unsigned int;       // type used for residues modulo `mod`
constexpr word mod = 998244353;  // prime modulus used for all arithmetic
constexpr word nimod = 29;       // size of (and index modulus for) the ni[] table

/// Modular exponentiation by repeated squaring: returns a^b reduced modulo `mod`.
constexpr ull pow(word a, word b) {
    ull result = 1;
    ull base = a;
    while (b != 0) {
        // Multiply in the current power of the base when this exponent bit is set.
        if ((b & 1u) != 0) result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return result;
}
// log2 of the largest transform length the tables below support (2^18 points).
#define mx 18
// mx-1: root[]/inv[] entries from index 1<<mx_ upward are computed directly,
// the lower half is derived from them (see READ's constructor).
#define mx_ 17
// root[len|i] / inv[len|i] (i<len): forward / inverse twiddle factors used by the
// butterfly of half-length len. realid[size|i]: index i with its log2(size) bits reversed.
word root[1<<mx],inv[1<<mx],realid[1<<(mx+1)];
// ni[(2^k)%nimod] holds the modular inverse of 2^k; read as ni[size%nimod] to scale
// the inverse transform.
word ni[nimod];
/**
 * Start-up helper and fast unsigned-integer reader for stdin.
 *
 * The constructor of the single global instance `cin` fills the NTT tables
 * (root, inv, realid, ni) and primes the one-character look-ahead.
 * operator>> then parses non-negative decimal integers.
 *
 * End of input is tracked explicitly in `eof`: once getchar() reports EOF,
 * operator>> stops scanning and stores 0 instead of looping forever.
 */
struct READ{
    char c;   // current look-ahead character (set to 0 once eof is true)
    bool eof; // true after getchar() has returned EOF
    inline READ():eof(false){//NTT precomputation
        // num1 = 3^((mod-1)/2^mx); 3 is a primitive root of 998244353, so num1 has order 2^mx.
        const ull num1=pow(3,(mod-1)>>mx);
        // num2 = num1^(mod-2), the modular inverse of num1.
        const ull num2=pow(num1,mod-2);
        root[1<<mx_]=inv[1<<mx_]=ni[1]=1;
        ni[2]=pow(2,mod-2),realid[1]=0;
        for(word floor=1,i=2;floor<=mx;++floor){
            // inverse of 2^floor = inverse of 2^(floor-1) times inverse of 2
            ni[(1u<<floor)%nimod]=1ull*ni[(1u<<(floor-1))%nimod]*ni[2]%mod;
            // bit reversal for length 2^floor, built from the table for length 2^(floor-1)
            for(;i<1u<<(floor+1);++i)
                realid[i]=(i&1)<<(floor-1)|realid[i>>1];
        }
        // finest level: consecutive powers of num1 (and of its inverse)
        for(word i=(1<<mx_)+1;i<1<<mx;++i){
            root[i]=num1*root[i-1]%mod;
            inv[i]=num2*inv[i-1]%mod;
        }
        // coarser levels reuse every second entry of the next finer level
        for(word i=(1<<mx_)-1;i;--i)
            root[i]=root[i<<1],inv[i]=inv[i<<1];
        next();}
    /**
     * Skips non-digit characters, then parses a run of decimal digits into num.
     * No sign handling and no overflow check. At end of input num is set to 0.
     */
    template<typename type>
    inline READ& operator >>(type& num){//fast read
        while(!eof&&('0'>c||c>'9')) next();
        for(num=0;!eof&&'0'<=c&&c<='9';next())
            num=num*10+(c-'0');
        return *this;
    }
private:
    // Fetches the next character, keeping getchar()'s int result long enough to detect EOF.
    inline void next(){
        const int ch=getchar();
        if(ch==EOF) eof=true,c=0;
        else c=static_cast<char>(ch);
    }
}cin;
// Loop header of an iterative radix-2 transform over `size` points: `floor` is the
// butterfly half-length, `head` the start of each group of 2*floor entries and `i`
// the offset inside the group. The statement that follows the macro is the loop body.
#define nttfor(size)                                \
    for(word floor=1;floor<(size);floor<<=1)        \
        for(word head=0;head<(size);head+=floor<<1) \
            for(word i=0;i<floor;++i)//transform loops
// One butterfly on num[head|i] and num[head|i|floor] with twiddle factor root[i|floor].
// Must be expanded inside nttfor, which supplies head, i and floor. Both results are
// reduced into [0,mod). Written as an immediately-invoked lambda so that several
// butterflies can be chained with the comma operator in a single loop body.
#define ntt(num,root) [&](){                            \
    word num1=num[head|i],num2=num[head|i|floor];       \
    num1+=(num2=1ull*num2*root[i|floor]%mod);           \
    num[head|i]=num1>=mod? num1-mod:num1;               \
    num1-=num2,num1+=mod-num2;                          \
    num[head|i|floor]=num1>=mod? num1-mod:num1;}()//butterfly
#define id(size,i) realid[(i)|(size)]//bit-reversed index of i for a size-point transform
// Counting loop over i in [0,size).
#define FOR(size) for(word i=0;i<(size);++i)

// Scratch buffers for cdq(): eax, ebx and ecx receive the bit-reversed transform inputs
// (f_{mid,r}, g_{l,mid}, g_{mid,r}); edx keeps a copy of f_{l,mid}.
word eax[1<<mx],ebx[1<<mx],ecx[1<<mx],edx[1<<mx];
// ans: input coefficients a_i, overwritten in place with the result by cdq().
// get: storage for the g polynomials built during the recursion.
word ans[1<<mx],get[1<<mx];
// Polynomial view: coefficients live in an external buffer, and one extra
// coefficient (`end`) is tracked separately from that buffer.
struct poly{
    word *const num; // coefficient buffer (not owned by this object)
    word end;        // the separately handled coefficient
    // Deliberately not explicit: callers initialise a poly directly from a pointer.
    inline poly(word *const in)
        : num(in), end(0) {}
    // Copies the buffer pointer and the extra coefficient; the buffer itself is shared.
    inline poly(const poly &other)
        : num(other.num), end(other.end) {}
};
/**
 * Divide-and-conquer step computing, for the block [l,r) with r-l==size,
 *   g_{l,r}(x) = prod_{i=l}^{r-1} (x-i)   and   f_{l,r}(x) = sum_{i=l}^{r-1} a_i g_{l,i}(x).
 *
 * Each polynomial keeps its coefficients of x^0..x^{size-1} in num[0..size) and the
 * coefficient of x^size in `end`. Products are formed with a cyclic convolution of
 * length `size`; the x^size term that wraps onto x^0 is known in advance (it is the
 * product of the two `end` values), so it is subtracted from num[0] and kept in `end`.
 *
 * @param ans  on entry ans.num[0..size) holds a_l..a_{r-1}; on return it holds f_{l,r}
 * @param get  on return holds g_{l,r}
 * @param size block length; expected to be a power of two not exceeding 1<<mx
 *             (main passes one; not checked here)
 * @param id   index of this block among blocks of the same size; at a leaf it is the
 *             i of the factor (x-i)
 *
 * Uses the global scratch buffers eax, ebx, ecx and edx.
 */
inline void cdq(poly& ans,poly& get,const word size,const word id){
    if(size==1){
        ans.end=0,get.end=1;    //f_{i,i+1}(x)=a_i (a_i is already in ans.num[0])
        get.num[0]=id? mod-id:0;//g_{i,i+1}(x)=x-i, with -i reduced modulo mod
        return;
    }
    poly ans0=ans.num,ans1=ans.num+(size>>1);//f_{l,mid}(x),f_{mid,r}(x)
    poly get0=get.num,get1=get.num+(size>>1);//g_{l,mid}(x),g_{mid,r}(x)
    cdq(ans0,get0,size>>1,id<<1);
    cdq(ans1,get1,size>>1,id<<1|1);
    //keep f_{l,mid}: ans.num is overwritten by the product below.
    //The byte count is derived from sizeof(word) instead of assuming 4-byte words.
    memcpy(edx,ans0.num,(size>>1)*sizeof(word));
    FOR(size>>1){
        //low half of each input, written at its bit-reversed position
        word head=id(size,i);
        eax[head]=ans1.num[i];
        ebx[head]=get0.num[i];
        ecx[head]=get1.num[i];
        //high half is zero padding
        head=id(size,size>>1|i);
        eax[head]=ebx[head]=ecx[head]=0;
    }
    //the separately stored coefficients belong to x^{size/2}
    const word head=id(size,size>>1);
    eax[head]=ans1.end;
    ebx[head]=get0.end;
    ecx[head]=get1.end;
    nttfor(size) ntt(eax,root),ntt(ebx,root),ntt(ecx,root);
    FOR(size){
        //pointwise products, stored bit-reversed for the inverse transform
        const word head=id(size,i);
        ans.num[head]=1ull*eax[i]*ebx[i]%mod;//f_{mid,r}(x)g_{l,mid}(x), the product term of f_{l,r}(x)
        get.num[head]=1ull*ebx[i]*ecx[i]%mod;//g_{l,r}(x)=g_{l,mid}(x)g_{mid,r}(x)
    }
    nttfor(size) ntt(ans.num,inv),ntt(get.num,inv);
    //scale by the inverse of size to finish the inverse transform
    word num1=ni[size%nimod],num2;
    FOR(size){
        ans.num[i]=1ull*num1*ans.num[i]%mod;
        get.num[i]=1ull*num1*get.num[i]%mod;
    }
    //coefficient of x^size of each product
    ans.end=1ull*ans1.end*get0.end%mod;
    get.end=1ull*get0.end*get1.end%mod;
    num1=ans.end? mod-ans.end:0;
    num2=get.end? mod-get.end:0;
    ans.num[0]=(num1+ans.num[0])%mod;
    get.num[0]=(num2+get.num[0])%mod;//remove the x^size term that wrapped onto x^0
    FOR(size>>1) if((ans.num[i]+=edx[i])>=mod)
        ans.num[i]-=mod;//+f_{l,mid}(x)
    //f_{l,mid}'s separately stored coefficient sits at x^{size/2}
    if((ans.num[size>>1]+=ans0.end)>=mod)
        ans.num[size>>1]-=mod;
}
int main(){
    word n,size;
    cin>>n;
    for(word i=0;i<n;++i) cin>>ans[i];
    for(size=1;size<n;size<<=1);
    poly p0=ans,p1=get;
    cdq(p0,p1,size,0);
    for(word i=0;i<n;++i) printf("%u ",ans[i]);
    return 0;
}