题解 CF1918F

· · 题解

更新:删除中文标点与英文之间的空格。

思路

显然遍历完每个叶子整个树也遍历完了,而且蹦床只会在叶子使用。

显然我们每次在遍历完一个子树后再遍历其它子树是最优的。无蹦床情况下,答案为 2n-2-\max dep(选择最深的点作为终点)。

考虑叶子 u 到叶子 v,如果不使用蹦床,dis(u,v)=dep_u+dep_v-2dep_{lca},如果使用,dis(1,v)=dep_v。此次使用蹦床就会使答案减少 dis(u,v)-dis(1,v)=dep_u-2dep_{lca}

发现 v 和答案无关,所以我们只需要考虑每个叶子通过哪一个祖先转向另一个叶子。当然,我们要先遍历完以该叶子到目标祖先中间所有点为根的子树。

要使答案最小,就要使前 k 大的 dep_u-2dep_{x} 最大(xu 的祖先),就要贪心地让大的 dep_u 和尽量的小的 dep_{x} 配对。

为什么这样贪心是对的呢?因为可选的 dep_{x} 数量与非终点的叶子结点数量相同(一个 x 可以选 s_x-1 次,其中 s_x 为儿子数),每个叶子都能和一个 dep_{x} 配对,所以不用担心选了 x 之后其它叶子没得选;因为这是一棵树,所以能替换 u 的所有点都有一个相同的“下位替代” t,对于同样被选择的,可以和 x 配对的两个叶子,谁去和“下位替代”配对是一样的。

也就是说,我们要将 dep 更大的 u 作为更大的子树的最后访问叶子。这样就能确定整个树的遍历顺序。然后将每次叶子之间的转移使用蹦床可以减少的代价放进优先队列,寻找前 k 大的代价即可。

代码

#include<bits/stdc++.h>
using namespace std;
int read() {
    int f = 1, x = 0;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-')f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return f * x;
}
void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9)write(x / 10);
    putchar(x % 10 + '0');
}
const int N = 2e5 + 10, MOD = 1e9 + 7, INF = 0x3f3f3f3f;
int fa[N], n = read(), k = read(), head[N], tot, dep[N], tmp[N], ans;
bool gg[N];
struct Edge {
    int to, nxt;
} e[N];
void add(int u, int v) {
    e[++tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}
priority_queue<int>q;
void dfs(int pos) {
    dep[pos] = dep[fa[pos]] + 1;
    tmp[pos] = pos;
    for (int i = head[pos]; i; i = e[i].nxt) {
        dfs(e[i].to);
        if (dep[tmp[pos]] < dep[tmp[e[i].to]])tmp[pos] = tmp[e[i].to];
    }
}
void sfd(int pos) {
    for (int i = head[pos]; i; i = e[i].nxt) {
        sfd(e[i].to);
        if (tmp[pos] != tmp[e[i].to])q.push(-2 * dep[pos] + dep[tmp[e[i].to]]);
    }
}
signed main() {
    //freopen(".in", "r", stdin);
    //freopen(".out", "w", stdout);
    for (int i = 2; i <= n; i++) {
        fa[i] = read();
        add(fa[i], i);
    }
    dep[0] = -1;
    dfs(1);
    sfd(1);
    ans = 2 * n - 2 - dep[tmp[1]];
    while (!q.empty() && k) {
        if (q.top() <= 0)break;
        ans -= q.top();
        q.pop();
        k--;
    }
    write(ans);
    return 0;
}