独钓寒江雪(树形 dp/同构)

· · 题解

传送门

考虑任意选一个点当根,设状态为 f_{u,0/1} 表示 u 的子树,选择/不选 u不同构方案数。

可以想到有 f_{u,0}=\prod\limits_{v\in son_u}(f_{v,0}+f_{v,1})f_{u,1}=\prod\limits_{v\in son_u}f_{v,0},但是这样算肯定是错的,所以考虑枚举 v 的时候每一个同构类一起讨论(判断同构可以写树哈希,如果被卡了考虑在上面加一点类似深度和儿子个数之类的)。

注意到对于两种同构的方案,一定在某一点处选择了两棵同构的不同子树,比如:

显然左右同构,因为 2 子树和 3 子树同构。

所以就有了一个思路:如果现在计算的是 f_{u,0},假设当前同构类有 x 个点,这 x 个点的 f 均为 g_{0/1}(因为同构所以 f 相等),相当于可重复地从 g_{0}+g_{1} 中选择 x 种进行组合,方案数为 C_{g_0+g_1+x-1}^{x}(隔板法可证)。最后 f_{u,0} 就是所有等价类答案的积,f_{u,1} 是同理的。

最后注意一点:如果任意选的点能与其他点当根的树同构,答案就是错的。考虑在无根树中的特殊点重心,可以考虑选重心当根,这样就最多两个,而且相邻,直接特判即可。

#include<bits/stdc++.h>
#define N 500009
#define eb emplace_back
#define ll unsigned long long
using namespace std;
inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline int read(){ int res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(int x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); }
const int mod=1e9+7;
int ksm(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
int n,Size[N],ss[N],rt[2],f[N][2],inv[N];
ll hs[N],a[N];
vector<int>V[N];
void dfs(int u,int from){
    Size[u]=1;
    for(auto v:V[u])if(v!=from)dfs(v,u),Size[u]+=Size[v],ss[u]=max(ss[u],Size[v]);
    ss[u]=max(ss[u],n-Size[u]);
}
unordered_map<ll,int>mp,mp1;
void dfs1(int u,int from){
    Size[u]=1;
    for(auto v:V[u]){
        if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
        if(v!=from)dfs1(v,u),Size[u]+=Size[v],hs[u]+=a[Size[v]]*hs[v];
    }
    hs[u]++;hs[u]=hs[u]*Size[u];
}
int C(int n,int m){
    int ans=inv[m];
    for(int i=n-m+1;i<=n;i++)ans=1ll*ans*i%mod;
    return ans;
}
void dfs2(int u,int from){
    f[u][0]=f[u][1]=1;
    for(auto v:V[u]){
        if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
        if(v!=from)dfs2(v,u);
    }
    mp.clear(),mp1.clear();
    for(auto v:V[u]){
        if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
        if(v!=from)mp[hs[v]]++,mp1[hs[v]]=v;
    }
    for(auto x:mp){
        f[u][0]=1ll*f[u][0]*C(x.second+f[mp1[x.first]][1]+f[mp1[x.first]][0]-1,x.second)%mod;
        f[u][1]=1ll*f[u][1]*C(x.second+f[mp1[x.first]][0]-1,x.second)%mod;
    }
}
signed main(){
    srand(time(0));
    n=read();
    a[0]=1;for(int i=1;i<=n;i++)a[i]=a[i-1]*rand()*i%mod;
    inv[1]=1;for(int i=2;i<=n;i++)inv[i]=1ll*inv[i-1]*ksm(i,mod-2)%mod;
    for(int i=1,x,y;i<n;i++)x=read(),y=read(),V[x].eb(y),V[y].eb(x);
    dfs(1,0);
    rt[0]=1;
    for(int i=2;i<=n;i++){
        if(ss[i]<ss[rt[0]])rt[0]=i,rt[1]=0;
        else if(ss[i]==ss[rt[0]])rt[1]=i;
    }
    if(rt[1]){
        n++,V[n].eb(rt[0]),V[rt[0]].eb(n),V[n].eb(rt[1]),V[rt[1]].eb(n),dfs1(n,0),dfs2(n,0);
    }
    else{
        dfs1(rt[0],0),dfs2(rt[0],0);
    }
    if(rt[1]){
        if(hs[rt[0]]==hs[rt[1]])printf("%lld",(C(f[rt[0]][0]+1,2)+1ll*f[rt[0]][0]*f[rt[1]][1]%mod)%mod);
        else printf("%lld",(1ll*f[rt[0]][0]*f[rt[1]][0]%mod+1ll*f[rt[0]][1]*f[rt[1]][0]%mod+1ll*f[rt[0]][0]*f[rt[1]][1]%mod)%mod);
    }
    else printf("%d",(f[rt[0]][0]+f[rt[0]][1])%mod);
    return 0;
}