[JSOI2019]神经网络

· · 题解

神仙题啊orz。树上背包+EGF。

考虑这个哈密顿回路可以拆分成什么。任意两棵树之间都可以到达,那么就可以拆分成一堆树的链的一个环,任意两个在环上相邻的链都不属于一棵树。首先求出每棵树拆分成j个链有多少种方案。注意一条长度大于1的链可以翻转过来,是两种方案。树上背包dp一下即可,dp式见代码,写的有点丑,请见谅。

然后发现如果就是链排列的话,直接把这m棵树的以链数为下标的EGF乘到一起就是答案的EGF,每一项系数加起来即可。然而这题相邻的颜色不可以一样,那么考虑容斥。假设我们求至少j对相邻的,我们就把钦定的这j对捏成一个大链,方便计算。那么dp式定义有了改变,变成了考虑前i棵树,有j个大链的带容斥系数方案数。注意是带容斥系数的,因为每棵树的容斥系数也要相乘才能得到这个拆分方案的总容斥系数。依然是有标号背包形式,也可以转化成EGF卷积。

第一棵树有一些特殊限定条件,先不考虑他。考虑其他树的EGF是什么。式子中的f_i表示这棵树拆分成i条链的方案数,已通过上文dp求出。j枚举的是我们钦定了多少对相邻颜色相同的链。这是除了第一棵树以外其他树的EGF:

\sum\limits_{i=1}^{k}f_i\times i!\sum\limits_{j=0}^{i-1}(-1)^jC_{i-1}^j \frac{x^{i-j}}{(i-j)!}

首先拆分成i条链方案数为f_i,内部排列是i!,然后容斥系数为(-1)^j,在这i条链中,有i-1个空隙,从中选j个钦定其为相邻,方案数为C_{i-1}^ji条链j对相邻,那么就有i-j条大链,所以是EGF的第i-j项。

然后考虑第一棵树的限制怎么处理。第一个限制:第一条链必须属于第一棵树,且包含1号节点。那么就钦定第一棵树的包含1号节点的链是第一个即可。他不参与内部排列,不参与外部的有标号卷积,所以拆分成i条链内部排列变成(i-1)!,变成EGF的第i-j-1

\sum\limits_{i=1}^{k}f_i\times (i-1)!\sum\limits_{j=0}^{i-1}(-1)^jC_{i-1}^j \frac{x^{i-j-1}}{(i-j-1)!}

第二个限制:最后一条链不可以属于第一棵树。那么上面的EGF是不考虑第二个限制的,把不符合第二个限制的部分减掉即可。也就是说我们钦定最后一条链属于第一棵树,把这个EGF减掉即可。最后一条链属于第一棵树的话,在这i条链里面任选一个不参与外部的有标号卷积,但是内部排列他还要参加,因为任何一条属于第一棵树的链都可以做最后一条链,所以内部排列还是(i-1)!,变成EGF的第i-j-2项。

\sum\limits_{i=1}^{k}f_i\times (i-1)!\sum\limits_{j=0}^{i-2}(-1)^jC_{i-1}^j \frac{x^{i-j-2}}{(i-j-2)!}

假如说j=i-1的话,相当于是这棵树的所有链都挨在一起,且包含第一条链和最后一条链,这等价于一共只有1棵树。这种情况下i-j-2出现了-1项,是没有定义的。所以说严格来讲这题应该限制m>1,否则这个做法是有问题的。把所有树的EGF都卷起来,得到的EGF别忘了每一项要乘以i!,才是有标号背包转移得到的结果。把所有答案加起来就是容斥出来的结果。

真的是道非常好的题,认真思考会对EGF理解加深很多。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
using namespace std;
#define N 300002
typedef long long ll;
const int p=998244353;
struct edge{int to,nxxt;}e[N<<1];
int m,k[302],head[N],cnt=1,dp[5002][5002][3],siz[N],t[5002][3],ct[5002];
int inv[N],fac[N],finv[N],f[N],g[N],py[N];
inline void ins(int u,int v){e[cnt].to=v;e[cnt].nxxt=head[u];head[u]=cnt++;}
void df5(int te,int la)
{siz[te]=1;dp[te][1][0]=1;
    for(int i=head[te];i;i=e[i].nxxt)
    {
        int j=e[i].to;if(j==la)continue;
        df5(j,te);
        for(int ii=1;ii<=siz[te]+siz[j];ii++)t[ii][0]=t[ii][1]=t[ii][2]=0;
        for(int ii=1;ii<=siz[te];ii++)for(int jj=1;jj<=siz[j];jj++)
        {
            int tdp=(1ll*dp[j][jj][0]+2ll*dp[j][jj][1]%p+dp[j][jj][2])%p;
            (t[ii+jj][0]+=1ll*dp[te][ii][0]*tdp%p)%=p;
            (t[ii+jj][1]+=1ll*dp[te][ii][1]*tdp%p)%=p;
            (t[ii+jj][2]+=1ll*dp[te][ii][2]*tdp%p)%=p;
            (t[ii+jj-1][1]+=1ll*dp[te][ii][0]*dp[j][jj][0]%p)%=p;
            (t[ii+jj-1][1]+=1ll*dp[te][ii][0]*dp[j][jj][1]%p)%=p;
            (t[ii+jj-1][2]+=2ll*dp[te][ii][1]*(dp[j][jj][0]+dp[j][jj][1])%p)%=p;
        }siz[te]+=siz[j];
        for(int ii=1;ii<=siz[te];ii++)
        dp[te][ii][0]=t[ii][0],dp[te][ii][1]=t[ii][1],dp[te][ii][2]=t[ii][2];
    }
}
inline int C(int nn,int mm)
{
    if(nn<mm)return 0;
    if(mm==0||nn==mm)return 1;
    return 1ll*fac[nn]*finv[mm]%p*finv[nn-mm]%p;
}
int main()
{//freopen("7.in","r",stdin);
    fac[0]=finv[0]=inv[1]=fac[1]=finv[1]=1;
    for(int i=2;i<=100000;i++)
    {
        inv[i]=1ll*(p-p/i)*inv[p%i]%p;
        fac[i]=1ll*fac[i-1]*i%p;finv[i]=1ll*finv[i-1]*inv[i]%p;
    }
    scanf("%d",&m);f[0]=1;int sum=0;
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&k[i]);
        for(int j=1;j<=k[i];j++)head[j]=ct[j]=0;cnt=1;
        for(int j=1;j<k[i];j++)
        {
            int x,y;scanf("%d%d",&x,&y);
            ins(x,y),ins(y,x);
        }
        for(int ii=1;ii<=k[i];ii++)for(int j=1;j<=k[i];j++)dp[ii][j][0]=dp[ii][j][1]=dp[ii][j][2]=0;
        df5(1,1);for(int j=1;j<=k[i];j++)ct[j]=(1ll*dp[1][j][0]+2ll*dp[1][j][1]%p+dp[1][j][2])%p;
        for(int j=0;j<=k[i];j++)g[j]=0;
        //for(int j=1;j<=k[i];j++)printf("%d ",ct[j]);puts("");
        if(i^1)
        {
            for(int ii=1;ii<=k[i];ii++)
            for(int j=0;j<ii;j++)
            {
                if(j&1)g[ii-j]=(g[ii-j]-1ll*ct[ii]*fac[ii]%p*C(ii-1,j)%p+p)%p;
                else g[ii-j]=(g[ii-j]+1ll*ct[ii]*fac[ii]%p*C(ii-1,j)%p)%p;
            }
        }
        else
        {
            for(int ii=1;ii<=k[i];ii++)for(int j=0;j<ii;j++)
            {
                if(j&1)g[ii-j-1]=(g[ii-j-1]-1ll*ct[ii]*fac[ii-1]%p*C(ii-1,j)%p+p)%p;
                else g[ii-j-1]=(g[ii-j-1]+1ll*ct[ii]*fac[ii-1]%p*C(ii-1,j)%p)%p;
            }
            for(int ii=1;ii<=k[i];ii++)for(int j=0;j<ii-1;j++)
            {
                if(j&1)g[ii-j-2]=(g[ii-j-2]+1ll*ct[ii]*fac[ii-1]%p*C(ii-1,j)%p)%p;
                else g[ii-j-2]=(g[ii-j-2]-1ll*ct[ii]*fac[ii-1]%p*C(ii-1,j)%p+p)%p;
            }
        }
        for(int ii=0;ii<=k[i];ii++)g[ii]=1ll*g[ii]*finv[ii]%p;
        for(int ii=0;ii<=sum+k[i];ii++)py[ii]=0;
        for(int ii=0;ii<=sum;ii++)for(int jj=0;jj<=k[i];jj++)
        (py[ii+jj]+=1ll*f[ii]*g[jj]%p)%=p;
        sum+=k[i];
        for(int ii=0;ii<=sum;ii++)f[ii]=py[ii];
    }
    int ans=0;
    for(int ii=0;ii<=sum;ii++)ans=(ans+1ll*f[ii]*fac[ii]%p)%p;
    printf("%d\n",ans);
}
/*
3
4
2 1
4 3
3 1
5
4 5
3 4
3 2
1 2
1
*/