题解 P5860 【「SWTR-03」Counting Trees】

· · 题解

\color{blue}\text{(重制于 2020.11.26)}

由于时间关系没有现场打,赛后来肝题。

看到数树以为是大毒瘤,没想到这树就是来开玩笑的……

树的必要条件 : \text{度数和}=2(\text{点数}-1)

若度数符合要求,根据 \rm prufer 序列,必然能够造出符合要求的树。则上面的是充要条件。

移项整理得到 \text{度数和}-2\text{点数}=-2

我们每加一个点扣除 2 点贡献,每有一个度数加上 1 点贡献,具体处理方式就是 \rm OGF

全部卷积起来之后,[x^{-2}] 项就是答案。

(我在CSP2019曾因为没看出这个trick白丢了16分qwq)

也就是说我们需要求 [x^{-2}]\prod\limits_{i=1}^n(1+x^{v_i-2})

怎么还有负数次项啊,又没法分治NTT,一脸不可做的样子……

首先对于 v_i=2 的点单独提出来(次数为0),设有 m 个,最后答案乘以 2^m

首先对于 v^i<2 的部分,只可能有 (1+x^{-1}) 。这部分直接记录次数 c ,二项式定理即可。

得到 F1(x) ,最低有 -c 次项,注意到正负项可能相消,我们后面正次数的多项式保留 \bmod\ x^{c-1} 的项即可。

对于 v_i>2 的部分,无法分治NTT。

a_i=v_i-2 (此时余下的所有 a_i>0 )

{\rm Ans}=\prod\limits_{i=1}^n(1+x^{a_i})

类似 P4389 考虑将每个东西 \ln 之后加在一起再 \exp 回来。

不难发现 \ln(1+x^a) 只在 a 的倍数次数处有值。

具体而言,根据 -\ln(1-x)=\sum\limits_{i=1}\dfrac{x^i}{i}

可推知 \ln(1+x^a)=\sum\limits_{i=0}\dfrac{(-x)^{ia}}{i}

需要累加的项是 O\Big(\sum\limits_{i=1}^nn/i\Big)=O(n\log n) 的。

剩下的就是 P4389 了,得到这部分的 F2(x)

最后把 F1,F2 卷起来就是答案,负指数加法卷积的规则和正指数类似。

本题中可以把 F1 整体平移来处理。

#include<algorithm>
#include<cstdio>
#include<ctime>
#define ll long long
#define mod 998244353
#define G 3
#define Maxn 530000
using namespace std;
inline int read()
{
  register int X=0;
  register char ch=0;
  while(ch<48||ch>57)ch=getchar();
  while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
  return X;
}
ll powM(ll a,ll t=mod-2)
{
  ll ans=1;
  while(t){
    if(t&1)ans=ans*a%mod;
    a=a*a%mod;
    t>>=1;
  }return ans;
}
int tr[Maxn<<1];
ll invG=powM(G);
void NTT(ll *f,short op,int n)
{
  for (int i=0;i<n;i++)
    if (i<tr[i])swap(f[i],f[tr[i]]);
  for(int p=2;p<=n;p<<=1){
    int len=p>>1,
        tG=powM(op==1 ? G:invG,(mod-1)/p);
    for (int k=0;k<n;k+=p){
      ll buf=1;
      for (int i=k;i<k+len;i++){
        int sav=f[len+i]*buf%mod;
        f[len+i]=(f[i]-sav+mod)%mod;
        f[i]=(f[i]+sav)%mod;
        buf=buf*tG%mod;
      }
    }
  }
}
ll _g1[Maxn<<1];
void times(ll *f,ll *g,int len1,int len2,int lim)
{
  int m=len1+len2-1,n;
  for(int i=0;i<len2;i++)_g1[i]=g[i];
  #define g _g1
  for(n=1;n<m;n<<=1);
  for(int i=len2;i<n;i++)g[i]=0;
  ll invn=powM(n);
  for(int i=0;i<n;i++)
    tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
  NTT(f,1,n);NTT(g,1,n);
  for(int i=0;i<n;++i)f[i]=f[i]*g[i]%mod;
  NTT(f,-1,n);
  for(int i=0;i<lim;++i)
    f[i]=f[i]*invn%mod;
  for(int i=lim;i<n;++i)f[i]=0;
  #undef g
}
ll _w2[Maxn<<1],_r2[Maxn<<1];
void invp(ll *f,int m)
{
  int n;for (n=1;n<m;n<<=1);
  #define w _w2
  #define r _r2
  w[0]=powM(f[0]);
  for (int len=2;len<=n;len<<=1){
    for (int i=0;i<(len>>1);i++)
      r[i]=(w[i]<<1)%mod;
    times(w,w,len>>1,len>>1,len);
    times(w,f,len,len,len);
    for (int i=0;i<len;i++)
      w[i]=(r[i]-w[i]+mod)%mod;
  }for (int i=0;i<m;i++)f[i]=w[i];
  for (int i=0;i<n;i++)w[i]=r[i]=0;
  #undef w
  #undef r
}
ll fac[Maxn],inv[Maxn],inp[Maxn];
ll C(int n,int m)
{return fac[n]*inv[m]%mod*inv[n-m]%mod;}
void Init(int m)
{
  int lim=1;
  for (;lim<m;lim<<=1)
  fac[0]=1;
  for (int i=1;i<=lim;i++)
    fac[i]=fac[i-1]*i%mod;
  inv[lim]=powM(fac[lim]);
  for (int i=lim;i;i--){
    inv[i-1]=inv[i]*i%mod;
    inp[i]=inv[i]*fac[i-1]%mod;
  }
}
void dao(ll *f,int m)
{
  for (int i=1;i<m;i++)
   f[i-1]=f[i]*i%mod;
  f[m-1]=0;
}
void jifen(ll *f,int m)
{
  for (int i=m;i;i--)
    f[i]=f[i-1]*powM(i)%mod;
  f[0]=0;
}
ll _s3[Maxn<<2];
void lnp(ll *f,int m)
{
  ll *sav=_s3;
  for (int i=0;i<m;i++)sav[i]=f[i];
  invp(sav,m);dao(f,m);
  times(f,sav,m-1,m,m);
  jifen(f,m-1);
  for (int i=0;i<m;i++)sav[i]=0;
}
ll _xp[Maxn<<2],_xp2[Maxn<<2];
void exp(ll *f,int m)
{
  ll *s=_xp,*s2=_xp2;
  int n=1;for(;n<m;n<<=1);
  s2[0]=1;
  for (int len=2;len<=n;len<<=1){
    for (int i=0;i<(len>>1);i++)s[i]=s2[i];
    lnp(s,len);
    for (int i=0;i<len;i++)
      s[i]=(f[i]-s[i]+mod)%mod;
    s[0]=(s[0]+1)%mod;
    times(s2,s,len>>1,len,len);
  }for (int i=0;i<m;i++)f[i]=s2[i];
  for (int i=0;i<n;i++)s[i]=s2[i]=0;
}
ll buf,F1[Maxn],F2[Maxn];
int n,c[Maxn],m;
int main()
{
  n=read();
  for (int i=1;i<=n;i++)c[read()]++;
  m=c[1];
  if (m<2){printf("0");return 0;}
  Init(m);
  for (int i=0;i<m-1;i++)F1[i]=C(m,i);
  m--;
  for (int i=3;i<=n;i++){
    ll a=i-2;
    for (int j=1;j*a<m;j++)
      F2[j*a]=(F2[j*a]+1ll*c[i]*((j&1) ? inp[j] : mod-inp[j]))%mod;
  }
  exp(F2,m);
  times(F1,F2,m,m,m);
  printf("%lld",F1[m-1]*powM(2,c[2])%mod);
  return 0;
}