题解 P4491 【[HAOI2018]染色】

· · 题解

广告 : 炫酷反演魔术 & 多项式计数杂谈

二项式反演经典题。

G[k] 表示有恰好 k 种颜色出现了 S 次的方案数。

我们需要求出 W[1...n] 即可解决问题,这个 n=\min(M,N/S)

一个naive的想法是 : 钦定k种颜色,每种强行染上S次。

剩下的颜色放任自流,随便怎么染。

得到 \dbinom{m}{k}*\dfrac{n!}{(S!)^k(n-S*k)!}*(n-S*k)^{m-k}

组合意义 : "选 k 种颜色\times可重排列\times放任自流"

那么这个东西肯定不是 G[k] ,也并不是所谓 "至少 k 种颜色出现了 S 次的方案数"

事实上,这是会重复计数的,但是仍然以优美的方式蕴含了 G[k]

我们设其为 F[k]

类比至一般情况,得到 F[k]=\sum\limits_{i=k}^{n}\dbinom{i}{k}*G[i]

已知 F 要求 G ,不就是二项式反演的经典操作嘛!

反演可得 G[k]=\sum\limits_{i=k}^n(-1)^{i-k}\dbinom{i}{k}F[i]

到现在复杂度还是 O(n^2) 的,我们考虑把组合数拆开看看能否卷积:

G[k]=\sum\limits_{i=k}^n(-1)^{i-k}\dfrac{i!}{k!(i-k)!}F[i] G[k]*k!=\sum\limits_{i=k}^n\dfrac{(-1)^{i-k}}{(i-k)!}i!F[i]

这可以明显地看到差卷积形式了。

A[i]=i!*F[i],\ B[i]=\dfrac{(-1)^i}{i!}

则有G[k]=\frac{1}{k!}\sum\limits_{i=k}^nA[i]B[i-k]

NTT计算差卷积即可,复杂度O(N+n\log n)

Code:

#include<algorithm>
#include<cstdio>
#define ll long long
using namespace std;
const int _G=3,mod=1004535809,Maxn=135000,MaxNum=10000500;
ll powM(ll a,int t=mod-2){
  ll ans=1;
  while(t){
    if(t&1)ans=ans*a%mod;
    a=a*a%mod;t>>=1;
  }return ans;
}
const int invG=powM(_G);
int tr[Maxn<<1];
void NTT(int *g,bool op,int n)
{
  #define ull unsigned long long
  static ull f[Maxn<<1],w[Maxn]={1};
  for (int i=0;i<n;i++)f[i]=g[tr[i]];
  for(int l=1;l<n;l<<=1){
    ull tG=powM(op?_G:invG,(mod-1)/(l+l));
    for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod;
    for(int k=0;k<n;k+=l+l)
      for(int p=0;p<l;p++){
        int tt=w[p]*f[k|l|p]%mod;
        f[k|l|p]=f[k|p]+mod-tt;
        f[k|p]+=tt;
      }      
    if (l==(1<<10))
      for (int i=0;i<n;i++)f[i]%=mod;
  }if (!op){
    ull invn=powM(n);
    for(int i=0;i<n;++i)
      g[i]=f[i]%mod*invn%mod;
  }else for(int i=0;i<n;++i)g[i]=f[i]%mod;
}
int n,m,S,lim,limNum;
ll w[Maxn],fac[MaxNum],ifac[MaxNum];
void Init()
{
  limNum=max(n,m);fac[0]=1;
  for (int i=1;i<=limNum;i++)
    fac[i]=fac[i-1]*i%mod;
  ifac[limNum]=powM(fac[limNum]);
  for (int i=limNum;i;i--)
    ifac[i-1]=ifac[i]*i%mod;
}
ll C(int n,int m)
{return fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
inline ll clacF(int x){
  return 
    C(m,x)*fac[n]%mod*
    powM(ifac[S],x)%mod*ifac[n-S*x]%mod*
        powM(m-x,n-S*x)%mod;
}
int A[Maxn<<1],B[Maxn<<1];
int main()
{
  scanf("%d%d%d",&n,&m,&S);
  lim=min(m,n/S);Init();
  for (int i=0;i<=lim;i++){
    A[i]=clacF(i)*fac[i]%mod;
    B[i]=(i&1) ? mod-ifac[i] : ifac[i];
  }reverse(A,A+lim+1);
  for(n=1;n<lim+lim+2;n<<=1);
  for (int i=0;i<n;i++)
    tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
  NTT(A,1,n);NTT(B,1,n);
  for (int i=0;i<n;i++)A[i]=1ll*A[i]*B[i]%mod;
  NTT(A,0,n);reverse(A,A+lim+1);
  ll ans=0;
  for (int i=0,w;i<=lim;i++){
    scanf("%d",&w);
    ans=(ans+A[i]*ifac[i]%mod*w)%mod;
  }printf("%lld",ans);
  return 0;
}