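// Solution notes (the original notes are partly garbled/truncated):
// - 造树据 …  [garbled in the source; original wording lost]
// - 由于 …  ("Since …" — sentence truncated in the source)
// - 考虑树形 dp。对于子树 …
//   Use a tree DP; for each subtree … (sentence truncated in the source)
// - 考虑统计答案，由于题意中给出的是无根树，所以要换根统计。
//   To count the answer: the tree given in the statement is unrooted, so the
//   DP has to be evaluated for every possible root via rerooting.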
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using ull = unsigned long long;
// Array capacity bound: per-vertex arrays below are sized N+2.
constexpr int N = 500000;
// Prime modulus; inverses elsewhere are computed as x^(mod-2) (Fermat).
constexpr int mod = 998244353;

// Residue modulo `mod`.
// Invariant: the arithmetic operators keep v in [0, mod) as long as their
// operands are in [0, mod). The int constructor does NOT reduce its argument,
// so callers must pass a value already in [0, mod).
struct modint{
    int v;
    modint():v(0){}
    modint(int x):v(x){}
    // Modular addition: with both operands in [0, mod) the sum is < 2*mod
    // (which fits in int), so one conditional subtraction suffices.
    modint operator+(modint o)const{
        int t=v+o.v;
        if(t>=mod){
            t-=mod;
        }
        return {t};
    }
    // Modular subtraction: the difference is in (-mod, mod), so one
    // conditional addition brings it back into [0, mod).
    modint operator-(modint o)const{
        int t=v-o.v;
        if(t<0){
            t+=mod;
        }
        return {t};
    }
    // Modular multiplication. The product is formed in 64 bits; the reduced
    // value is < mod, so converting back to int is lossless. The explicit cast
    // replaces the former `return {long long}` list-initialization, which is an
    // ill-formed narrowing conversion (a hard error on clang).
    modint operator*(modint o)const{
        return modint(static_cast<int>(1ll*v*o.v%mod));
    }
};
// Input/output holder: reads the tree from stdin and prints the answer.
struct IO{
    int n;                  // number of vertices
    int u[N+2],v[N+2];      // edge idx (1 <= idx <= n-1) joins u[idx] and v[idx]
    // Read n, then the n-1 edges as endpoint pairs.
    void init(){
        cin>>n;
        for(int idx=1;idx<=n-1;++idx){
            cin>>u[idx]>>v[idx];
        }
    }
    modint ans;             // the answer to print
    // Print the answer's residue followed by a newline.
    void output(){
        cout<<ans.v<<'\n';
    }
};
// Binary exponentiation: returns a^b using modint arithmetic.
// b is expected to be non-negative.
modint qpow(modint a,ll b){
    modint result=1;
    for(;b;b>>=1){
        if(b&1) result=result*a;
        a=a*a;  // square the base for the next bit
    }
    return result;
}
// Factorial table fac[i] = i! and inverse-factorial table inv[i] = 1/i!
// (modulo `mod`) for 0 <= i <= N, both filled by the constructor.
struct Comb{
    modint fac[N+2],inv[N+2];
    Comb(){
        fac[0]=1;
        for(int i=1;i<=N;++i) fac[i]=fac[i-1]*modint(i);
        // Invert N! once (Fermat), then step downward: 1/(i-1)! = 1/i! * i.
        inv[N]=qpow(fac[N],mod-2);
        for(int i=N;i>0;--i) inv[i-1]=inv[i]*modint(i);
    }
};
struct solution{
IO io;
solution(IO io):io(io){}
vector<int> e[N+2];
void prework(){
for(int i=1;i<io.n;i++){
e[io.u[i]].push_back(io.v[i]);
e[io.v[i]].push_back(io.u[i]);
}
}
ull h(ull x){
return x*x*123+x*567+8910;
}
ull g(ull x){
return h(x&((1ull<<32)-1))+h(x>>32);
}
int siz[N+2];
modint f[N+2];
ull Hash[N+2];
map<ull,int> mp[N+2];
Comb c;
void dfs1(int u,int fa){
siz[u]=1;
f[u]=1;
Hash[u]=1;
for(int v:e[u]) if(v!=fa){
dfs1(v,u);
siz[u]+=siz[v];
f[u]=f[u]*f[v]*c.inv[siz[v]];
Hash[u]+=g(Hash[v]);
f[u]=f[u]*qpow(++mp[u][Hash[v]],mod-2);
}
f[u]=f[u]*c.fac[siz[u]-1];
}
map<ull,int> flag;
void dfs2(int u,int fa){
if(!flag[Hash[u]]){
io.ans=io.ans+f[u];
flag[Hash[u]]=1;
}
for(int v:e[u]) if(v!=fa){
f[u]=f[u]*(mp[u][Hash[v]]--);
f[u]=f[u]*qpow(f[v],mod-2)*c.fac[siz[v]]*c.inv[io.n-1]*c.fac[io.n-siz[v]-1];
Hash[u]-=g(Hash[v]);
Hash[v]+=g(Hash[u]);
f[v]=f[v]*f[u]*c.inv[io.n-siz[v]]*c.inv[siz[v]-1]*c.fac[io.n-1];
f[v]=f[v]*qpow((++mp[v][Hash[u]]),mod-2);
dfs2(v,u);
f[v]=f[v]*(mp[v][Hash[u]]--);
f[v]=f[v]*qpow(f[u],mod-2)*c.fac[io.n-siz[v]]*c.fac[siz[v]-1]*c.inv[io.n-1];
Hash[v]-=g(Hash[u]);
Hash[u]+=g(Hash[v]);
f[u]=f[u]*f[v]*c.inv[siz[v]]*c.fac[io.n-1]*c.inv[io.n-siz[v]-1];
f[u]=f[u]*qpow((++mp[u][Hash[v]]),mod-2);
}
}
void solve(){
prework();
dfs1(1,0);
io.ans=0;
dfs2(1,0);
io.ans=io.ans*c.inv[io.n-1];
}
};
int main(){
// freopen("data.in","r",stdin);
cin.tie(0)->sync_with_stdio(0);
IO io;
io.init();
solution s(io);
s.solve();
s.io.output();
return 0;
}