独钓寒江雪(树形 dp/同构)
Judgelight · · 题解
传送门
考虑任意选一个点当根,设状态为
可以想到有
注意到对于两种同构的方案,一定在某一点处选择了两棵同构的不同子树,比如:
显然左右同构,因为
所以就有了一个思路:如果现在计算的是
最后注意一点:如果任意选的点能与其他点当根的树同构,答案就是错的。考虑在无根树中的特殊点重心,可以考虑选重心当根,这样就最多两个,而且相邻,直接特判即可。
#include<bits/stdc++.h>
#define N 500009
#define eb emplace_back
#define ll unsigned long long
using namespace std;
inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline int read(){ int res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(int x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); }
const int mod=1e9+7;
int ksm(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return ans;
}
int n,Size[N],ss[N],rt[2],f[N][2],inv[N];
ll hs[N],a[N];
vector<int>V[N];
void dfs(int u,int from){
Size[u]=1;
for(auto v:V[u])if(v!=from)dfs(v,u),Size[u]+=Size[v],ss[u]=max(ss[u],Size[v]);
ss[u]=max(ss[u],n-Size[u]);
}
unordered_map<ll,int>mp,mp1;
void dfs1(int u,int from){
Size[u]=1;
for(auto v:V[u]){
if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
if(v!=from)dfs1(v,u),Size[u]+=Size[v],hs[u]+=a[Size[v]]*hs[v];
}
hs[u]++;hs[u]=hs[u]*Size[u];
}
int C(int n,int m){
int ans=inv[m];
for(int i=n-m+1;i<=n;i++)ans=1ll*ans*i%mod;
return ans;
}
void dfs2(int u,int from){
f[u][0]=f[u][1]=1;
for(auto v:V[u]){
if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
if(v!=from)dfs2(v,u);
}
mp.clear(),mp1.clear();
for(auto v:V[u]){
if((u==rt[0]&&v==rt[1])||(u==rt[1]&&v==rt[0]))continue;
if(v!=from)mp[hs[v]]++,mp1[hs[v]]=v;
}
for(auto x:mp){
f[u][0]=1ll*f[u][0]*C(x.second+f[mp1[x.first]][1]+f[mp1[x.first]][0]-1,x.second)%mod;
f[u][1]=1ll*f[u][1]*C(x.second+f[mp1[x.first]][0]-1,x.second)%mod;
}
}
signed main(){
srand(time(0));
n=read();
a[0]=1;for(int i=1;i<=n;i++)a[i]=a[i-1]*rand()*i%mod;
inv[1]=1;for(int i=2;i<=n;i++)inv[i]=1ll*inv[i-1]*ksm(i,mod-2)%mod;
for(int i=1,x,y;i<n;i++)x=read(),y=read(),V[x].eb(y),V[y].eb(x);
dfs(1,0);
rt[0]=1;
for(int i=2;i<=n;i++){
if(ss[i]<ss[rt[0]])rt[0]=i,rt[1]=0;
else if(ss[i]==ss[rt[0]])rt[1]=i;
}
if(rt[1]){
n++,V[n].eb(rt[0]),V[rt[0]].eb(n),V[n].eb(rt[1]),V[rt[1]].eb(n),dfs1(n,0),dfs2(n,0);
}
else{
dfs1(rt[0],0),dfs2(rt[0],0);
}
if(rt[1]){
if(hs[rt[0]]==hs[rt[1]])printf("%lld",(C(f[rt[0]][0]+1,2)+1ll*f[rt[0]][0]*f[rt[1]][1]%mod)%mod);
else printf("%lld",(1ll*f[rt[0]][0]*f[rt[1]][0]%mod+1ll*f[rt[0]][1]*f[rt[1]][0]%mod+1ll*f[rt[0]][0]*f[rt[1]][1]%mod)%mod);
}
else printf("%d",(f[rt[0]][0]+f[rt[0]][1])%mod);
return 0;
}