题解 P5641 【【CSGRound2】开拓者的卓识】

· · 题解

题解

当我们求sum_{k,1,r}时,我们考虑每一个a[i](i\in[1,r])的贡献

通过我们对题目公式描述的理解,可以想象为求k+1个二元组(l,r),使l_i<=l_{i+1},r_i>=r_{i+1},i\in[1,k]这是一个很显然的组合数问题。

所以sum_{k,1,r}=\sum_{i=1}^{r}a[i] * C(i+k-2,k-1) * C(r-i+k-1,k-1)

于是我们就有了一个O(n^2)的算法

Code:

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
#define maxn 1005
using namespace std;
ll power(ll n,ll x){
    ll sp=1LL;
    while(x){
        if(x&1LL) sp=(sp*n)%mod;
        n=(n*n)%mod;
        x>>=1LL;
    }return sp;
}
ll fact[maxn<<1],inv[maxn<<1];
void init(){
    fact[0]=1LL;
    for(int i=1;i<=2000;i++)
    fact[i]=fact[i-1]*i%mod;
    inv[2000]=power(fact[2000],mod-2);
    for(int i=1999;i>=0;i--)
    inv[i]=inv[i+1]*(i+1)%mod;  
}
ll C(ll n,ll m){
    return fact[n]*inv[n-m]%mod*inv[m]%mod;
}
int n;
ll k,a[maxn];
int main(){
    scanf("%d%lld",&n,&k);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    init();
    for(int i=1;i<=n;i++){
        ll res=0;
        for(int j=1;j<=i;j++)
        res=(res+C(j+k-2,k-1)*C(i-j+k-1,k-1)%mod*a[j]%mod)%mod;
        printf("%lld ",res);
    }return 0;
}

考虑继续优化这个式子

我们可以发现,在一个sum_{k,1,r}中,(i+k-2)+(r-i+k-1)=r+2\times k - 3

而这个是一个常量,就可以把它变成一个卷积的式子

A_i=C(i+k-1,k-1)\times a_{i+1},B_i=C(i+k-1,k-1)(i\in[0,n))

sum_{k,1,r}=\sum_{i=0}^{r-1}A_i * B_{r-i-1} (\mod 998244353)

因为要模998244353,所以可以NTT解决,换个模数就是MTT了

Code:

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
#define maxn 100005
using namespace std;
ll power(ll n,ll x){
    ll sp=1LL;
    while(x){
        if(x&1LL) sp=(sp*n)%mod;
        n=(n*n)%mod;
        x>>=1LL;
    }return sp;
}
const ll g=3,gi=332748118;
int r[maxn*3],limit,L;
inline void NTT(ll *a,int type){
    for(int i=0;i<limit;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<limit;mid<<=1){
        ll Wn=power(type==1?g:gi,(mod-1)/(mid<<1));
        for(int j=0;j<limit;j+=(mid<<1)){
            ll w=1;
            for(int k=0;k<mid;k++,w=w*Wn%mod){
                ll x=a[j+k],y=w*a[j+k+mid]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
}
int n;
ll k,a[maxn*3],b[maxn*3];
ll in[maxn];
int main(){
    scanf("%d%lld",&n,&k);
    in[0]=in[1]=1;
    for(int i=2;i<n;i++) in[i]=(mod-mod/i)*in[mod%i]%mod;//线性求逆元
    b[0]=1;
    for(int i=1;i<n;i++) b[i]=b[i-1]*(i+k-1)%mod*in[i]%mod;//因为k太大了,无法预处理阶乘,需要递推
    for(int i=0;i<n;i++)
    scanf("%lld",&a[i]),a[i]=(a[i]*b[i])%mod;//处理Ai,Bi
   //开始NTT
    limit=1;
    while(limit<=n*2) limit<<=1,L++;
    for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
    NTT(a,1);NTT(b,1);
    for(int i=0;i<limit;i++) a[i]=(a[i]*b[i])%mod;
    NTT(a,-1);
    ll inv=power(limit,mod-2);
    for(int i=0;i<n;i++) printf("%lld ",a[i]*inv%mod);
    return 0;
}