造树据

· · 题解

由于 fa_u<u,显然原问题可以转化为给树上对应结点标号。(即树上拓扑序)

考虑树形 dp。对于子树 u,编号最小的肯定是根结点。剩下的 siz_u-1 个编号要依次分给每个子树,则方案数为 \frac{(siz_u-1)!}{\Pi_{v\in son_u}siz_v!}。但是我们注意到,对于两个同构的子树,显然,他们之间的编号分配是等价的。所以要用树哈希来判同构,则可以将所有子树中同构的分成一组。假设有 k 组,第 i 组大小为 cnt_i,则 f(u)=\frac{(siz_u-1)!\times \Pi_{v\in son_u}f(v)}{\Pi_{v\in son_u}siz_v!\times \Pi_{i=1}^k cnt_i!}

考虑统计答案,由于题意中给出的是无根树,所以要换根统计 f(u) 的和。注意到,同构的两棵树答案显然只用每统计一次,同时对树哈希换根即可。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N=5e5;
const int mod=998'244'353;
struct modint{
    int v;
    modint():v(0){}
    modint(int x):v(x){}
    modint operator+(modint o){
        int t=v+o.v;
        if(t>=mod){
            t-=mod;
        }
        return {t};
    }
    modint operator-(modint o){
        int t=v-o.v;
        if(t<0){
            t+=mod;
        }
        return {t};
    }
    modint operator*(modint o){
        return {1ll*v*o.v%mod};
    }
};
struct IO{
    int n;
    int u[N+2],v[N+2];
    void init(){
        cin>>n;
        for(int i=1;i<n;i++){
            cin>>u[i]>>v[i];
        }
    }
    modint ans;
    void output(){
        cout<<ans.v<<"\n";
    }
};
modint qpow(modint a,ll b){
    modint ans=1;
    while(b){
        if(b&1){
            ans=ans*a;
        }
        a=a*a;
        b>>=1;
    }
    return ans;
}
struct Comb{
    modint fac[N+2],inv[N+2];
    Comb(){
        fac[0]=1;
        for(int i=1;i<=N;i++){
            fac[i]=fac[i-1]*modint(i);
        }
        inv[N]=qpow(fac[N],mod-2);
        for(int i=N-1;i>=0;i--){
            inv[i]=inv[i+1]*modint(i+1);
        }
    }
};
struct solution{
    IO io;
    solution(IO io):io(io){}
    vector<int> e[N+2];
    void prework(){
        for(int i=1;i<io.n;i++){
            e[io.u[i]].push_back(io.v[i]);
            e[io.v[i]].push_back(io.u[i]);
        }
    }
    ull h(ull x){
        return x*x*123+x*567+8910;
    }
    ull g(ull x){
        return h(x&((1ull<<32)-1))+h(x>>32);
    }
    int siz[N+2];
    modint f[N+2];
    ull Hash[N+2];
    map<ull,int> mp[N+2];
    Comb c;
    void dfs1(int u,int fa){
        siz[u]=1;
        f[u]=1;
        Hash[u]=1;
        for(int v:e[u]) if(v!=fa){
            dfs1(v,u);
            siz[u]+=siz[v];
            f[u]=f[u]*f[v]*c.inv[siz[v]];
            Hash[u]+=g(Hash[v]);
            f[u]=f[u]*qpow(++mp[u][Hash[v]],mod-2);
        }
        f[u]=f[u]*c.fac[siz[u]-1];
    }
    map<ull,int> flag;
    void dfs2(int u,int fa){
        if(!flag[Hash[u]]){
            io.ans=io.ans+f[u];
            flag[Hash[u]]=1;
        }
        for(int v:e[u]) if(v!=fa){
            f[u]=f[u]*(mp[u][Hash[v]]--);
            f[u]=f[u]*qpow(f[v],mod-2)*c.fac[siz[v]]*c.inv[io.n-1]*c.fac[io.n-siz[v]-1];
            Hash[u]-=g(Hash[v]);
            Hash[v]+=g(Hash[u]);
            f[v]=f[v]*f[u]*c.inv[io.n-siz[v]]*c.inv[siz[v]-1]*c.fac[io.n-1];
            f[v]=f[v]*qpow((++mp[v][Hash[u]]),mod-2);
            dfs2(v,u);
            f[v]=f[v]*(mp[v][Hash[u]]--);
            f[v]=f[v]*qpow(f[u],mod-2)*c.fac[io.n-siz[v]]*c.fac[siz[v]-1]*c.inv[io.n-1];
            Hash[v]-=g(Hash[u]);
            Hash[u]+=g(Hash[v]);
            f[u]=f[u]*f[v]*c.inv[siz[v]]*c.fac[io.n-1]*c.inv[io.n-siz[v]-1];
            f[u]=f[u]*qpow((++mp[u][Hash[v]]),mod-2);
        }
    }
    void solve(){
        prework();
        dfs1(1,0);
        io.ans=0;
        dfs2(1,0);
        io.ans=io.ans*c.inv[io.n-1];
    }
};
int main(){
    // freopen("data.in","r",stdin);
    cin.tie(0)->sync_with_stdio(0);
    IO io;
    io.init();
    solution s(io);
    s.solve();
    s.io.output();
    return 0;
}