P11363 [NOIP2024] 树的遍历 题解

· · 题解

前言

来点魔怔点分治做法。

考场上最后五分钟胡出 k=1 做法,没测大样例就结束了,最后发现一个地方没取模。在重做这题的时候模数又写成 998244353,我真无敌了。

观察性质

首先 k=1 的答案是 P=\prod_u (deg_u-1)!。这是因为,当我们第一次走到与点 u 相连的边时,我们还剩下 (deg_u-1) 条边可以走,而它们的顺序是可以任意排的。

接下来我们容斥,用 k\times P 减掉算重的。

我们先尝试研究,新图的 dfs 树有什么样的性质。

性质 1:原树中交于同一点的边在 dfs 树中必然是一条连续的链。

这是因为,如果 dfs 过程中从当前蓝边进入了其他子树,那么不可能通过它们回到蓝边,所以只能把其他蓝边作为当前蓝边在 dfs 树上的儿子继续遍历。

我们记这种链为“特殊链”。

这个性质告诉我们:一条边可以作为根生成某棵 dfs 树的充要条件是这条边是所有包含它的特殊链的端点;在交于同一点的边中,最多有两条可以同时作为根生成同一棵 dfs 树。

性质 2:若两条边可以作为根生成同一棵 dfs 树,那么原树中这两条边之间的所有边都可以作为根生成这棵 dfs 树。

这是因为,当蓝边作为根时,那么必然从蓝边进入 1 号点,从绿边进入 2 号点,从橙边进入 3 号点,也就是说,蓝边、绿边、橙边分别是 1,2,3 号点的特殊链的一个端点。同理,当红边作为根时,红边、橙边、绿边分别是 3,2,1 号点的特殊链的一个端点。红边,蓝边是我们钦定的根,而绿边,橙边同时作为包含它们的两条特殊链的端点,所以它们也是这棵 dfs 树的根。

因此,我们可以总结出,如果一个边集中的边可以同时作为根生成某棵 dfs 树,那么它们在原树中的虚树必须是一条链。因为假如不是,虚树中必定存在一个 \ge 3 度点,与“在交于同一点的边中,最多有两条可以同时作为根生成同一棵 dfs 树”矛盾。

\mathcal O(nk^2) 的做法

我们由此得到了一个 \mathcal O(nk^2) 的做法:

我们枚举一对关键边,满足它们之间没有其它关键边,计算同时以这一对边为根的 dfs 树的数量,然后减掉。这样做的正确性在于,考虑一棵 dfs 树,它能被 m 条关键边作为根生成出来,这 m 条关键边形成了一条链,我们一开始把这棵树算了 m 次,而这样做把它减掉了 (m-1) 次。

如何计算同时以某对边为根的 dfs 树的数量?设在这对边之间的点(不包括两个相距最远的点)构成的点集为 S,对于某个 u\in S,它的特殊链的两个端点都被固定,所以只能将相连的 (deg_u-2) 条边的顺序任意重排,因此答案为 \prod_{u\in S}(deg_u-2)!\prod_{u\notin S}(deg_u-1)!=\dfrac{P}{\prod_{u\in S}(deg_u-1)}

优化

我觉得 DP 写起来很麻烦啊!!!注意到我们要计算的问题是树上满足条件的路径的贡献之和,不难想到点分治。

考虑把当前重心 u 拉出来当根,计算所有包含 u 的路径的贡献。注意这里的路径为了不算漏,包括了这对边相距最远的那一对点,但是计算 \sum\dfrac{1}{deg-1} 的时候不要算上。

u 的儿子为 v_1\dots v_c,讨论 u 是否作为端点:

时间复杂度 \mathcal O(Tn\log n),我的点分治被卡常了,去掉找到重心后重新计算 siz 就过了(

#include<bits/stdc++.h>
bool Mst;
#define ll long long
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
#define m107 1000000007
template<class _T>
void ckmax(_T &x,_T y) {
    x=max(x,y);
}
int ri() {
    int x;
    cin>>x;
    return x;
}
#define inf 0x3f3f3f3f
#define mod m107
template<int MOD>
struct modint{
    int x;
    modint(){x=0;}
    template<typename T>
    int norm(T y){return (y%MOD+MOD)%MOD;}
    template<typename T>
    modint(T y){x=norm(y);}
    friend modint operator +(modint a,modint b){return a.x+b.x;}
    friend modint operator -(modint a,modint b){return a.x-b.x;}
    friend modint operator *(modint a,modint b){return 1ll*a.x*b.x;}
    modint& operator +=(modint b){return x=norm(x+b.x),*this;}
    modint& operator -=(modint b){return x=norm(x-b.x),*this;}
    modint& operator *=(modint b){return x=norm(1ll*x*b.x),*this;}
    friend istream& operator >>(istream&is,modint &x){
        ll v;
        return is>>v,x.x=norm(v),is;
    }
    friend ostream& operator <<(ostream&os,modint &x){
        return os<<x.x,os;
    }
};
using mint=modint<m107>;
mint inv(int v) {
    auto exgcd=[&](auto &exgcd,int a,int b,int &x,int &y)->int {
        if(b==0)return x=1,y=0,a;
        int g=exgcd(exgcd,b,a%b,y,x);
        y-=a/b*x;
        return g;
    };
    int x,y,g=exgcd(exgcd,v,mod,x,y);
    assert(g==1);
    return x;
}
mint fac[101000],iv[101000];
void Solve_() {
    int n,k;
    cin>>n>>k;
    vector<int>deg(n+1);
    vector<pii>G(n+n+2);
    vector<int>bg(n+1),ed(n+1);
    {
        vector<vector<pii>>g(n+1);
        for(int i=1,u,v; i<=n-1; ++i) {
            cin>>u>>v;
            g[u].pb({v,i}),g[v].pb({u,i});
            ++deg[u],++deg[v];
        }
        int cnt=0;
        rep(i,1,n) {
            bg[i]=cnt+1,ed[i]=bg[i]+deg[i]-1;
            for(pii t:g[i])G[++cnt]=t;
        }
    }
    vector<int>e(n+1);
    rep(i,1,k)e[ri()]=1;
    mint P=1,ans=k;
    rep(i,1,n)P*=fac[deg[i]-1];
    ans*=P;
    vector<int>siz(n+1),vis(n+1);
    int rt,maxp;
    auto getrt=[&](auto &getrt,int u,int in,int s,int &maxp)->void {
        siz[u]=1;
        int nw=0;
        rep(j,bg[u],ed[u]) {
            int v=G[j].fi,i=G[j].se;
            if(i==in||vis[v])continue;
            getrt(getrt,v,i,s,maxp);
            siz[u]+=siz[v];
            ckmax(nw,siz[v]);
        }
        ckmax(nw,s-siz[u]);
        if(nw<maxp)rt=u,maxp=nw;
    };
    auto calc=[&](int u,int in)->mint {
        mint res=0;
        auto dfs=[&](auto &dfs,int u,int in,mint prd)->void{
            if(e[in])return void(res+=prd);
            prd*=iv[deg[u]-1];
            rep(j,bg[u],ed[u]) {
                int v=G[j].fi,i=G[j].se;
                if(i==in||vis[v])continue;
                dfs(dfs,v,i,prd);
            }
        };
        dfs(dfs,u,in,1);
        return res;
    };
    auto calc2=[&](int u)->mint {
        mint res=0;
        auto dfs=[&](auto &dfs,int u,int in,mint prd)->void{
            if(e[in])return void(res+=prd);
            prd*=iv[deg[u]-1];
            rep(j,bg[u],ed[u]) {
                int v=G[j].fi,i=G[j].se;
                if(i==in||vis[v])continue;
                dfs(dfs,v,i,prd);
            }
        };
        dfs(dfs,u,0,1);
        return res;
    };
    auto solve=[&](auto &solve,int u)->void {
        getrt(getrt,u,0,siz[u],maxp=inf);
        vis[u=rt]=1;
        mint sum=0,pre=0;
        rep(j,bg[u],ed[u]) {
            int v=G[j].fi,i=G[j].se;
            if(vis[v])continue;
            mint f=calc(v,i);
            sum+=pre*f*iv[deg[u]-1];
            pre+=f;
            if(e[i])sum+=calc2(v);
        }
        ans-=P*sum;
        rep(j,bg[u],ed[u]) {
            int v=G[j].fi;
            if(vis[v])continue;
            solve(solve,v);
        }
    };
    siz[1]=n;
    solve(solve,1);
    cout<<ans<<'\n';
}
bool Med;
signed main() {
    fac[0]=1;
    rep(i,1,100000)fac[i]=fac[i-1]*i,iv[i]=inv(i);
    cerr<<(&Mst-&Med)/1024.0/1024.0<<" MB\n";
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int c,Testcases=1;
    cin>>c>>Testcases;
    while(Testcases--)Solve_();
    return 0;
}