题解:P11152 [THUWC 2018] 七彩序列
masterhuang · · 题解
-
特判
a_i 全相等,此时ans=0 。 -
记
\min a_i=l,\max a_i=r ,此时l<r 。 -
下面所有大写字母代表小写字母的生成函数。
第一步我就没注意到,其实很关键:
同时有前缀和后缀不好做,我们 把后缀的限制给弄到前缀上。
具体的,称一个前缀不合法,当且仅当这个前缀每个数出现的次数为
后者意义是填完这个前缀会让剩下的后缀不合法。
现在是一个标准的容斥形式,转换题意:
则
:::info[代码]
// 洛谷 P11152
// https://www.luogu.com.cn/problem/P11152
#include<bits/stdc++.h>
#define LL long long
#define fr(x) freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);
using namespace std;
const int N=2e7+5,M=1<<21|5,mod=998244353;
int n,l,r,u,a[105],jc[N],inv[N],_[M],w1[M],w2[M],w3[M],w[M],ans,U,f[M],g[N];
inline int bger(int x){return x|=x>>1,x|=x>>2,x|=x>>4,x|=x>>8,x|=x>>16,x+1;}
inline int md(int x){return x>=mod?x-mod:x;}
inline int ksm(int x,int p){int s=1;for(;p;(p&1)&&(s=1ll*s*x%mod),x=1ll*x*x%mod,p>>=1);return s;}
inline void init(int U)
{
for(int i=1,j,k;i<U;i<<=1)
for(w[j=i]=1,k=ksm(3,(mod-1)/(i<<1)),j++;j<(i<<1);j++)
w[j]=1ll*w[j-1]*k%mod;
}
inline void DNT(int *a,int U)
{
for(int i,j,k=U>>1,L,*W,*x,*y,z;k;k>>=1)
for(L=k<<1,i=0;i<U;i+=L)
for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
*y=1ll*(*x+mod-(z=*y))* *W%mod,*x=md(*x+z);
}
inline void IDNT(int *a,int U)
{
for(int i,j,k=1,L,*W,*x,*y,z;k<U;k<<=1)
for(L=k<<1,i=0;i<U;i+=L)
for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
z=1ll* *W* *y%mod,*y=md(*x+mod-z),*x=md(*x+z);
reverse(a+1,a+U);
for(int inv=ksm(U,mod-2),i=0;i<U;i++) a[i]=1ll*a[i]*inv%mod;
}
void INV(int num,int *a,int *b)
{
if(num==1) return b[0]=ksm(a[0],mod-2),void();
INV((num+1)>>1,a,b);int U=bger(num<<1);static int c[N];
for(int i=0;i<num;i++) c[i]=a[i];for(int i=num;i<U;i++) c[i]=0;
DNT(c,U);DNT(b,U);
for(int i=0;i<U;i++) b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
IDNT(b,U);for(int i=num;i<U;i++) b[i]=0;
}
int main()
{
cin.tie(0)->sync_with_stdio(0);cin>>n;
for(int i=1;i<=n;i++) cin>>a[i],u+=a[i];
sort(a+1,a+1+n);l=a[1],r=a[n];
if(l==r) return cout<<0,0;
for(int i=*jc=1;i<=u;i++) jc[i]=1ll*jc[i-1]*i%mod;
inv[u]=ksm(jc[u],mod-2);for(int i=u;i;i--) inv[i-1]=1ll*inv[i]*i%mod;
for(int i=0,s,S;i<=l;i++)
{
_[i]=w1[i]=1ll*jc[i*n]*ksm(inv[i],n)%mod;
if(i>=r-l)
{
s=0,S=1;
for(int j=1;j<=n;j++) S=1ll*S*inv[i+l-a[j]]%mod,s+=i+l-a[j];
w2[i]=1ll*S*jc[s]%mod;
}
s=0,S=1;
for(int j=1;j<=n;j++) S=1ll*S*inv[i-l+a[j]]%mod,s+=i-l+a[j];
w3[i]=1ll*S*jc[s]%mod;
}w1[0]=0;
init(U=bger((l+5)<<1));
DNT(w1,U);DNT(w2,U),DNT(w3,U);
for(int i=0;i<U;i++) f[i]=(1ll*w1[i]*w1[i]+1ll*(mod-w2[i])*w3[i])%mod;
IDNT(f,U);fill(f+l+1,f+U,0);f[0]++;
for(int i=1;i<=l;i++) f[i]=(f[i]+2ll*_[i])%mod;
INV(l+1,f,g);DNT(g,U);
for(int i=0;i<U;i++) g[i]=1ll*g[i]*w3[i]%mod;
IDNT(g,U);
return cout<<g[l],0;
}
:::