P10309 「Cfz Round 2」Max of Distance 神奇做法题解

· · 题解

是很有趣的构造题!膜拜出题人!在这里给出一种不太一样,但是比较好想的做法。

首先,我们先不管 E 的限制,假设所有的边权都为 1,试着求出这个期望:

由于我们知道到树上一点最远的点,一定是直径的一个端点,我们就可以先找出一条直径,然后从直径的两个端点开始分别 dfs 一遍,求出每个点的 \max\limits_{v=1}^n \mathbb{dis}(u,v),从而算出此时的期望。

如果我们给每条边赋上不同的边权,那么直径可能会变化,我们原本求出的这个期望就没有用了。但是,如果所有边的边权都相同的话,我们原本求出的直径,和现在的直径就是相同的!而且,假设所有边的边权都为 x,所有边权都为 1 的期望为 y,那么现在的期望就是 xy!因为原本的期望其实就是现在的边数期望,乘上边权就是答案的期望了。

那么当 y\ne 0 时,只要令所有边的边权都为 \frac{E}{y},就可以符合条件了。但是若 y=0,怎么办呢?

其实问题也不大。首先可以发现此时树不会是一条链(当 1\leq n\leq 10^5 时,没有任何一种长度的链,可以使得期望为 0);也就是说,至少有三个叶子节点;也就是说,一定存在不是直径端点的叶子节点

那么我们只要把与那个叶子节点相连的边的边权改成 \bmod 即可,这样取模后就为 0,这个点的 \max\limits_{v=1}^n \mathbb{dis}(u,v) 就会 -1,期望就不会是 0 了,就可以按原本的方法做了。

总时间复杂度 O(n),如果没有特判的话还是非常短的。

int n, E, u[N], v[N], p1, p2, dis[N], res[N], sum;
vector<int> e[N];

void dfs(int u, int fa)
{
    for (auto j : e[u])
        if (j ^ fa)
            dis[j] = dis[u] + 1, dfs(j, u);
}

int qpow(int x, int k)
{
    int res = 1;
    while (k)
    {
        if (k & 1)
            res = 1ll * res * x % mod;
        x = 1ll * x * x % mod, k >>= 1;
    }
    return res;
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n;
    for (int i = 1; i < n; i++)
        cin >> u[i] >> v[i], e[u[i]].push_back(v[i]), e[v[i]].push_back(u[i]);
    cin >> E;
    dfs(1, 0);
    for (int i = 1; i <= n; i++)
        if (dis[i] > dis[p1])
            p1 = i;
    dis[p1] = 0, dfs(p1, 0);
    for (int i = 1; i <= n; i++)
    {
        res[i] = max(res[i], dis[i]);
        if (dis[i] > dis[p2])
            p2 = i;
    }
    dis[p2] = 0, dfs(p2, 0);
    for (int i = 1; i <= n; i++)
        res[i] = max(res[i], dis[i]), sum = (sum + res[i]) % mod;
    if (!sum)
    {
        int tag;
        for (int i = 1; i < n; i++)
            if (e[u[i]].size() == 1 && u[i] != p1 && u[i] != p2)
            {
                tag = i;
                break;
            }
            else if (e[v[i]].size() == 1 && v[i] != p1 && v[i] != p2)
            {
                tag = i;
                break;
            }
        E = 1ll * n * E % mod * qpow(mod - 1, mod - 2) % mod;
        if (!E)
            E = mod;
        for (int i = 1; i < n; i++)
            if (i != tag)
                cout << E << "\n";
            else
                cout << mod << "\n";
        return 0;
    }
    E = 1ll * n * E % mod * qpow(sum, mod - 2) % mod;
    if (!E)
        E = mod;
    for (int i = 1; i < n; i++)
        cout << E << "\n";
    return 0;
}