P13627 题解

· · 题解

“以下至少一条为真”/tuu

第二条为真是同理的,只讲第一条为真咋做,注意最后要减去一根柱子的情况。

先考虑 a_1<a_2<a_3<\dots<a_ksum=n 的答案。不难发现 k\leq O(\sqrt n)。转化为每次将所有正数减一。考虑 dp_{i,j} 表示目前考虑的 sum=ia_{k-j+1}\sim a_k 还不为 0 的方案数,容易转移。

考虑把式子转化成 a_1<a_2<a_3<\dots<a_k 的形式。考虑容斥,将 > 容斥为无限制减去 \leq。因此,计算 a_1<a_2\leq a_3<a_4\leq a_5<a_6\leq \dotssum=n' 的方案数。同样可以 O(n\sqrt n) 计算。至于无限制,直接分治 NTT 合并即可。

注意有个小问题就是最后一个可能是 > 结尾,特判一下,NTT 卷出来之后再单独卷上包含最后一段的即可。

总复杂度 O(n\sqrt n+n\log n)

#include <bits/stdc++.h>
#define int long long
#define mid ((l+r)>>1)
#define lowbit(i) (i&(-i))
using namespace std;
const int mod=998244353;
inline void add(int &i,int j){
    i+=j;
    if(i>=mod) i-=mod;
}
inline int addv(int i,int j){
    i+=j;
    if(i>=mod) i-=mod;
    return i;
}
int dp[2][300005];
namespace Conv{
    typedef long long ll;
    const int mod=998244353,N=1050000,g=3,invg=(mod+1)/3;
    int wk[N+5],ta[N+5],tb[N+5];
    int Power(int x,int y){
        int r=1;
        while(y){
            if(y&1)r=1ll*r*x%mod;
            x=1ll*x*x%mod,y>>=1;
        }
        return r;
    }
    void DFT(int *a,int n){
        for(int i=n>>1;i;i>>=1){
            int w=Power(g,(mod-1)/(i<<1));
            wk[0]=1;
            for(int j=1;j<i;j++)wk[j]=1ll*wk[j-1]*w%mod;
            for(int j=0;j<n;j+=(i<<1)){
                for(int k=0;k<i;k++){
                    int x=a[j+k],y=a[i+j+k],z=x;
                    x+=y,(x>=mod&&(x-=mod)),a[j+k]=x;
                    z-=y,(z<0&&(z+=mod)),a[i+j+k]=1ll*z*wk[k]%mod;
                }
            }
        }
    }
    void IDFT(int *a,int n){
        for(int i=1;i<n;i<<=1){
            int w=Power(invg,(mod-1)/(i<<1));
            wk[0]=1;
            for(int j=1;j<i;j++)wk[j]=1ll*wk[j-1]*w%mod;
            for(int j=0;j<n;j+=(i<<1)){
                for(int k=0;k<i;k++){
                    int x=a[j+k],y=1ll*a[i+j+k]*wk[k]%mod,z=x;
                    x+=y,(x>=mod&&(x-=mod)),a[j+k]=x;
                    z-=y,(z<0&&(z+=mod)),a[i+j+k]=z;
                }
            }
        }
        for(int i=0,inv=Power(n,mod-2);i<n;i++)a[i]=1ll*a[i]*inv%mod;
    }
    vector<int> conv(vector<int> A,vector<int> B){
        int sa=A.size(),sb=B.size();
        vector<int> ret(sa+sb-1);
        int len=1;
        while(len<ret.size())len<<=1;
        for(int i=0;i<len;i++)ta[i]=tb[i]=0;
        for(int i=0;i<sa;i++)ta[i]=A[i];
        for(int i=0;i<sb;i++)tb[i]=B[i];
        DFT(ta,len),DFT(tb,len);
        for(int i=0;i<len;i++)ta[i]=1ll*ta[i]*tb[i]%mod;
        IDFT(ta,len);
        for(int i=0;i<ret.size();i++)ret[i]=ta[i];
        return ret;
    }
}
int qp(int a,int b){
    int ans=1;
    while(b){
        if(b&1) (ans*=a)%=mod;
        (a*=a)%=mod;
        b>>=1;
    }
    return ans;
}
int f[300005],g1[300005],g2[300005];
int p[300005],q[300005];
void dvqntt(int l,int r){
    if(l==r) return ;
    dvqntt(l,mid);
    vector<int> v1(mid-l+1),v2(r-l+1);
    for(int i=l;i<=mid;i++) v1[i-l]=p[i];
    for(int i=0;i<=r-l;i++) v2[i]=f[i];
    v1=Conv::conv(v1,v2);
    for(int i=mid+1;i<=r;i++) add(p[i],v1[i-l]);
    dvqntt(mid+1,r); 
}
signed main(){
//  freopen("test.in","r",stdin);
//  freopen("test.out","w",stdout);
//  ios::sync_with_stdio(false);
//  cin.tie(0),cout.tie(0);
    memset(dp,0,sizeof(dp));
    dp[0][0]=1;
    for(int j=1;j<=1600;j++){
        memset(dp[j&1],0,sizeof(dp[j&1]));
        if(j&1) for(int i=j;i<=300000;i++) dp[j&1][i]=addv(dp[j&1][i-j],dp[(j&1)^1][i-j]);
        else for(int i=0;i<=300000;i++) dp[j&1][i]=addv((i>=j)?dp[j&1][i-j]:0,dp[(j&1)^1][i]);
        for(int i=0;i<=300000;i++){
            int ex=dp[j&1][i];
            if(((j-1)>>1)&1) ex=mod-ex;
            if(j&1) add(g1[i],ex);
            else{
                if(i+j<=300000) add(f[i+j],ex);
            }
        }
    }
    memset(dp,0,sizeof(dp));
    dp[0][0]=1;
    for(int j=1;j<=1600;j++){
        memset(dp[j&1],0,sizeof(dp[j&1]));
        if(!(j&1)) for(int i=j;i<=300000;i++) dp[j&1][i]=addv(dp[j&1][i-j],dp[(j&1)^1][i-j]);
        else for(int i=0;i<=300000;i++) dp[j&1][i]=addv((i>=j)?dp[j&1][i-j]:0,dp[(j&1)^1][i]);
        for(int i=0;i<=300000;i++){
            int ex=dp[j&1][i];
            if(((j-1)>>1)&1) ex=mod-ex;
            if(j&1){
                if(i+j<=300000) add(g2[i+j],ex);
            }
        }
    }
//  for(int i=1;i<=10;i++) cout<<f[i]<<" "<<g1[i]<<" "<<g2[i]<<"\n";
    p[0]=1;
    dvqntt(0,300000);
    for(int i=0;i<=300000;i++) q[i]=addv(p[i],p[i]);
    vector<int> v1,v2;
    v1.resize(300001); v2.resize(300001);
    for(int i=0;i<=300000;i++) v1[i]=p[i],v2[i]=g1[i];
    v1=Conv::conv(v1,v2);
    for(int i=0;i<=300000;i++) add(q[i],v1[i]);
    v1.resize(300001); v2.resize(300001);
    for(int i=0;i<=300000;i++) v1[i]=p[i],v2[i]=g2[i];
    v1=Conv::conv(v1,v2);
    for(int i=0;i<=300000;i++) add(q[i],v1[i]);
    int t; cin>>t;
    while(t--){
        int n; cin>>n;
        cout<<(q[n]+mod-1)%mod<<"\n";
    }
    return 0;
}