题解:P11363 [NOIP2024] 树的遍历

· · 题解

菊花图且 k=1:所有的边都相邻,所以答案为 (n-2)!

:边的邻接关系同样构成链,所有只有一种生成树,答案为 1

k=1:观察下图的关键边和一种可能的连接方案,考虑每个点连接的边集,令其中最先访问的边为 e,则该点的邻边的连接情况一定是某种以 e 为端点的链,由 dfs 生成树无横叉边可证,而且 (\deg-1)! 种都可以取到。答案则是 \prod_{i}(\deg i-1)!

k=2:我们考察某条起始边的意义,令该边的一个端点为根,如上图所示,根以外的每个点首先被访问的邻边为它连向根的边,即邻边集的连接情况对应的链一定有一个端点是连向根的边。

考虑有多少种答案同时满足两个起始边的条件,令 X,Y 为各自的一个端点,且 X\ne Y,考虑取出树链 X-Y,在链之外的点的情况是平凡的,因为无论以 X 还是 Y 为根,朝向根的邻边不变。链上的中间点则需要保证链的两个端点恰好为朝向 XY 的,方案数由 (\deg-1)! 变为 (\deg-2)!。对于 XY,则需要根据它们代表的起始边的另一个端点的位置特殊讨论。

正解:可以先以 1 为根转化为有根树,对于每条关键边,可以用深度较大的端点唯一表示,下面可能会用“起始点”来描述对应的“起始边”。如果一组连接情况同时满足多个起始点,则它们一定在一条链上,否则会出现下图的不合法情况。

容易发现中间点处无解。

于是我们可以考虑树上 DP。先进行容斥,令 g(s) 为钦定了同时满足集合 s 中的起始点的方案数,答案显然是 \sum_{s}(-1)^{|s|+1}g(s)。故设计 DP 状态 f(u) 表示钦定的关键点集构成的链经过 u,且其中一个端点在 u 子树内的带权方案数,为了方便转移,可以对于每个点的方案都除以 (\deg-1)!。令 c=\frac{1}{\deg u-1}(当 \deg u=1c=1)。则转移为:

f(u)=\left\{ \begin{aligned} &\sum_{v}cf(v) \ \ \ \ (u\text{是某个起始点}) \\ &-1 \ \ \ \ else \end{aligned} \right.

贡献答案则是枚举两棵子树,或者当前点为起始点连向子树内某个点,是容易的。

最后答案应该乘以 \prod_{i}(\deg i-1)

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
constexpr int mod = 1000000000 + 7;
constexpr int maxn = 100000 + 10;
i64 qp(i64 a, i64 b)
{
    i64 c = 1;
    for (; b; b>>=1, a=a*a%mod)
        if (b & 1) c=c*a%mod;
    return c;
}
i64 fac[maxn], inv[maxn];
vector<int> g[maxn];
i64 f[maxn];
int o[maxn];
i64 ans;
void dp(int u, int fa)
{
    i64 s = 0, p = 0;
    for (int v : g[u]) if (v != fa)
    {
        dp(v, u);
        p = (p + s * f[v]) % mod;
        s = (s + f[v]) % mod;
    }
    i64 c = g[u].size() == 1 ? 1 : inv[g[u].size() - 1] * fac[g[u].size() - 2] % mod;
    ans = (ans - p * c % mod + mod) % mod;
    if (o[u])
    {
        f[u] = mod - 1;
        for (int v : g[u]) if (v != fa)
        {
            ans = (ans + f[v] * c % mod + mod) % mod;
        }
    }
    else f[u] = s * c % mod;
}
int d[maxn];
void dfs(int u, int fa)
{
    d[u] = d[fa] + 1;
    for (int v : g[u]) if (v != fa) dfs(v, u);
}
int u[maxn], v[maxn];
void solve()
{
    int n, k;
    cin >> n >> k;
    for (int i=1;i<=n;++i) g[i].clear();
    memset(f, 0, sizeof(f));
    memset(o, 0, sizeof(o));
    for (int i=1;i<n;++i)
    {
        cin >> u[i] >> v[i];
        g[u[i]].emplace_back(v[i]);
        g[v[i]].emplace_back(u[i]);
    }
    dfs(1, 0);
    for (int i=1;i<=k;++i)
    {
        int x;
        cin >> x;
        o[d[u[x]] < d[v[x]] ? v[x] : u[x]] = 1;
    }
    ans = k;
    dp(1, 0);
    for (int i=1;i<=n;++i) ans = ans * fac[g[i].size() - 1] % mod;
    cout << ans << '\n';
}
int main()
{
    fac[0] = 1;
    for (int i=1;i<maxn;++i) fac[i] = fac[i - 1] * i % mod;
    inv[maxn - 1] = qp(fac[maxn - 1], mod - 2);
    for (int i=maxn-2;i>=0;--i) inv[i] = inv[i + 1] * (i + 1) % mod;
    ios::sync_with_stdio(0);
    cin.tie(0);
    int c, t;
    cin >> c >> t;
    while (t--) solve();
    return 0;
}