题解:AT_abc422_g Balls and Boxes

· · 题解

思路

给出一个根号做法。

第一问是搞笑完全背包,做 3 次就行。

考虑第二问在求什么,假设三个盒子分别放了 x,y,z 个,求的就是 \sum \binom{n}{x}\times \binom{n-x}{y}\times \binom{n-x-y}{z}

我们化简一下这个式子,其实就是 \sum \dfrac{n!}{x!\times y!\times z!}

首先有一个显然的做法,枚举 x,y 硬算这个式子,复杂度是 O(\dfrac{n^2}{AB}) 的。

显然我们可以在 AB\ge \sqrt{n} 的时候做这个。

考虑 AB\le \sqrt{n} 怎么办,发现这时可以接受一个 O(nAB) 的东西。

于是我们设 f_{i,x,y} 表示当前放了 i 个球,第一个盒子放的球的个数模 Ax,第二个盒子放的球的个数模 By,且只往前两个盒子放球的方案数。

答案就是枚举一个 C 的倍数 z,求 \sum f_{n-z,0,0}\times \binom{n}{z}

于是做到了 O(n\sqrt{n})

代码

#include<bits/stdc++.h>
#define int long long
#define N 300005
#define mod 998244353
#define pii pair<int,int>
#define x first
#define y second
#define pct __builtin_popcount
#define mpi make_pair
#define inf 2e18
using namespace std;
int T=1,n,a[5],f[N],fac[N],inv[N];
vector<vector<int>>g[N];
int ksm(int x,int y){
    int res=1;
    while(y){
        if(y&1)(res*=x)%=mod;
        (x*=x)%=mod;
        y>>=1;
    }
    return res;
}
void init(){
    int n=3e5;
    fac[0]=inv[0]=1;
    for(int i=1;i<=n;i++){
        fac[i]=fac[i-1]*i%mod;
    }
    inv[n]=ksm(fac[n],mod-2);
    for(int i=n-1;i;i--){
        inv[i]=inv[i+1]*(i+1)%mod;
    }
}
int C(int n,int m){
    if(n<0||m<0||n<m)return 0;
    return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void solve(int cs){
    if(!cs)return;
    cin>>n>>a[1]>>a[2]>>a[3];
    f[0]=1;
    for(int i=1;i<=3;i++){
        for(int j=a[i];j<=n;j++){
            (f[j]+=f[j-a[i]])%=mod;
        }
    }
    cout<<f[n]<<'\n';
    if(a[1]*a[2]>=sqrt(n)){
        int res=0;
        for(int i=0;i<=n/a[1];i++){
            for(int j=0;j<=(n-a[1]*i)/a[2];j++){
                int cur=n-i*a[1]-j*a[2];
                if(cur%a[3])continue;
                (res+=fac[n]*inv[i*a[1]]%mod*inv[j*a[2]]%mod*inv[cur]%mod)%=mod;
            }
        }
        cout<<res<<'\n';
    }
    else{
        for(int i=0;i<=n;i++){
            g[i].resize(a[1]);
            for(int j=0;j<a[1];j++){
                g[i][j].resize(a[2]);
            }
        }
        g[0][0][0]=1;
        for(int i=1;i<=n;i++){
            for(int j=0;j<a[1];j++){
                for(int k=0;k<a[2];k++){
                    int x=j-1,y=k-1;
                    if(x<0)x=a[1]-1;
                    if(y<0)y=a[2]-1;
                    (g[i][j][k]+=g[i-1][x][k]+g[i-1][j][y])%=mod;
                }
            }
        }
        int res=0;
        for(int c=0;c<=n/a[3];c++){
            (res+=C(n,c*a[3])*g[n-c*a[3]][0][0]%mod)%=mod;
        }
        cout<<res<<'\n';
    }
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    // cin>>T;
    init();
    for(int cs=1;cs<=T;cs++){
        solve(cs);
    }
    return 0;
}