题解 P5339 【[TJOI2019]唱、跳、rap和篮球】
首先转换一下问题:
我们求有多少个排布方案,满足至少有
怎么求呢?
我们可以枚举有
为什么呢?
第一个方法是整体法(感谢@千年之狐_天才 大佬)
首先,对于每一种方案,有
我们可以考虑枚举这些没有被选中的位置。
把每一个讨论鸡你太美的组看成一个整体,缩成一个点。
这样就有
对于所有
否则就代表一个人。
我们直接从这
这样方案数就是
显然这样的枚举对应的方案是唯一的(可以把这些选为组的点展开,再顺序标号)。
第二个方法是转化法。
可以考虑转化一下问题。这等价于求
设
所以我们枚举
所以答案就等于
然后这么多位置已经固定了,怎么计算剩余不讨论鸡你太美的人的排列数呢???
真心难算!因为可能有些排列会有不只
所以我们考虑用容斥原理,枚举有
这样就可以排除干扰,对剩下的乱排列了。
设初始
答案
这个为什么对呢?
发现枚举至少一组的时候,对于一种可行的方案(这里代指枚举方案)会算
枚举至少两组的时候,会算
枚举至少i 组的时候,会算C_j^i 次至少j组的贡献(j\geq i)
所以我们可以通过容斥
来算出单个的贡献。这可以通过二项式展开来证明。
所以答案
设喜欢
这时候分别还剩下
相当于求有重复元素的排列!
我们知道,如果
那么排列答案就是
如果
如果
这就是一个四元卷积啦!
维护
模数真良心
我们前面还要用所有排列的个数减去答案,所以真正的答案其实就是
一个标准的容斥式子。
数学真是美妙啊!!!
对了,复杂度是
代码如下:
#include<bits/stdc++.h>
#define ll long long
#define ljc 998244353
using namespace std;
ll n,m,r[40006],lim;
ll a[20001],w[20001],b[20001],mini,num[4],c[20001],d[20001],fac[20010],inv[20001];
inline ll fast_pow(ll a,ll b,ll p){
ll t=1;a%=p;
while (b){
if (b&1) t=t*a%p;
b>>=1;a=a*a%p;
}
return t;
}
inline void NTT(ll f[],int lim,int id){
for (int i=0;i<lim;i++){
if (i<r[i]) swap(f[r[i]],f[i]);
}
w[0]=1;
for (int len=1;len<lim;len<<=1){
ll gen=fast_pow(3,(ljc-1)/(len<<1)*id+ljc-1,ljc);
for (int i=1;i<len;i++) w[i]=w[i-1]*gen%ljc;
for (int i=0;i<lim;i+=len<<1){
ll *f1=f+i,*f2=f1+len;
for (int j=0;j<len;j++){
ll x=f1[j],y=f2[j]*w[j]%ljc;
f1[j]=(x+y)%ljc;
f2[j]=(x-y+ljc)%ljc;
}
}
}
if (id==1) return;
ll Inv=fast_pow(lim,ljc-2,ljc);
for (int i=0;i<lim;i++) f[i]=f[i]*Inv%ljc;
}
inline ll P(ll n,ll A,ll B,ll C,ll D){
if (n>A+B+C+D) return 0;
if (n<0) return 0;
ll lim=1,len=0;
while (lim<(A+B+C+D<<1)) lim<<=1,len++;
for (int i=0;i<lim;i++) r[i]=(r[i>>1LL]>>1LL)|(((1LL*i)&1)<<(len-1LL));
for (int i=0;i<lim;i++) a[i]=(i<=A)?inv[i]:0;
for (int i=0;i<lim;i++) b[i]=(i<=B)?inv[i]:0;
for (int i=0;i<lim;i++) c[i]=(i<=C)?inv[i]:0;
for (int i=0;i<lim;i++) d[i]=(i<=D)?inv[i]:0;
NTT(a,lim,1);NTT(b,lim,1);NTT(c,lim,1);NTT(d,lim,1);
for (int i=0;i<lim;i++) a[i]=a[i]*b[i]%ljc*c[i]%ljc*d[i]%ljc;
NTT(a,lim,-1);
return fac[n]*a[n]%ljc;
}
inline ll C(ll m,ll n){
if (m<n) return 0;
return (fac[m]*inv[n]%ljc*inv[m-n]%ljc);
}
inline void init(int n){
n*=2;
inv[0]=inv[1]=1;fac[0]=1;
for (int i=1;i<=n;i++) fac[i]=fac[i-1]*i%ljc;
for (int i=2;i<=n;i++) inv[i]=(ljc-(ljc/i)*inv[ljc%i]%ljc)%ljc;
for (int i=2;i<=n;i++) inv[i]=inv[i-1]*inv[i]%ljc;
}
int main(){
cin>>n>>num[0]>>num[1]>>num[2]>>num[3];
mini=min(num[0],min(num[1],min(num[2],num[3])));
init(n);
ll ans=0,one=1;
for (int i=0;i<=min(mini,n/4);i++){
ans=(ans+one*C(n-3*i,i)%ljc*P(n-4*i,num[0]-i,num[1]-i,num[2]-i,num[3]-i)%ljc)%ljc;
one=ljc-one;
}
cout<<ans;
return 0;
}