P9962 [THUPC 2024 初赛] 一棵树

· · 题解

*P9962 [THUPC 2024 初赛] 一棵树

谁说点分治过不了?

思路来自 asmend。

从边的角度计算贡献,涉及黑点个数之差的绝对值,难以入手。

考虑从点的角度计算贡献。钦定黑点重心为 x,以 x 为根,则每条边的贡献为 k - 2 b_i,其中 b_i 表示这条边的子树的黑点个数,从而知每个点涂成黑点的贡献为 -2d_i,其中 dep_i 表示点 i 的深度。

因为钦定 x 是黑点的重心,所以 x 的每棵子树不能选超过 m = \frac k 2 个黑点(所有除法均为下取整)。棘手的地方在于,如果某棵子树选择了超过 m 个黑点,那么会得到更优的答案,也就是不合法的方案带来了更优值,所以必须严格满足合法性(若不合法的方案不会更优,那么就算 x 最终不是重心也没关系)。

考虑最优方案对应的黑点重心 x ^ *,那么 x ^ * 不会有子树希望选超过 m 个黑点,否则调整可得更优解。相反,对于钦定的重心 x,如果 x 没有子树希望选超过 m 个点,那么它一定是最优方案之一对应的黑点重心之一。否则考虑更优方案在以 x 为根时对应的权值,根据刚才的分析,是小于它真实权值的,但最优方案的权值又小于更优方案在以 x 为根时对应的权值(这条性质由贪心策略保证),所以最优方案的权值小于更优方案的权值,矛盾。于是,我们希望找到一个点,使得以它为根贪心时没有子树希望选超过 m 个黑点。

k 为偶数时,最多有一棵子树希望选超过 m 个黑点。如果这样的子树存在,那么 x ^ * 显然落在这棵子树里,因为容易证明对于子树外的任意 x',以 x' 为根时 x 方向的子树希望选超过 m 个黑点。任意选点,每次选一棵子树递归,点分治即可。

时间复杂度 \mathcal{O}(n\log ^ 2 n),且常数较大。将排序换成桶排序即可做到 \mathcal{O}(n\log n)

k 为奇数时,由于奇偶性的限制,可能存在两棵子树希望选超过 m 个黑点。这也是唯一需要特殊处理的情况。如果这种情况没有发生,类似点分治即可。注意特判 k = 1

实际上,我们断言,k 为奇数时的答案就是 k - 1 时的答案额外加入一个新增权值最小的黑点。考虑在当前方案的基础上再选一个黑点 y 产生的贡献(减小的权值)。

综合上述两种情况,y 产生的贡献为它到所有 “子树恰有 m 个黑点” 的点的最短距离。于是每个白点 y 变成黑点后产生的贡献就是容易计算的了。相信现在读者的心中还剩下一个谜团 ——

为什么会这样?

首先考虑 k - 1 最优方案的重心 x,那么 k 的最优方案不会在 x 的某棵子树选超过 m + 1 个点,否则 k - 1x 在该方向上希望选至少 m + 1 个点,与 xk - 1 时的最优性矛盾。

如果 k 的最优方案没有在 x 的某棵子树选 m + 1 个点,那么 x 就是 k 的重心,根据贪心过程,k 的最优方案由 k - 1 的最优方案多选一个黑点得到。

否则 k 的最优方案在 x 的某棵子树 T 选了 m + 1 个点。此时 k - 1 的最优方案一定在 T 选了 m 个点,否则存在 y\notin T 使得 y 的深度不小于所有 T 中 “在 k 的最优方案被选中但没有在 k - 1 的最优方案被选中” 的点的深度,且 y 没有在 k 的最优方案被选中,调整即得更优解。

子树之间的贡献相对独立,容易证明 T 以外其它子树的决策不变。现在只需要考虑 T 内部的决策。

枚举所有 m + 1 个黑点的 LCA d,在 d 的限制下的最优方案为先选 d 子树内 m 个深度最大的点,再选一个让 LCA 固定为 d。设选点集合为 S,黑点的总贡献为 2\sum_{i\in S} dep_i - 2dep_d,我们要最大化总贡献。

一个容易理解的性质是 m 个深度最大的点可以任选:假设我们想选择的 m 个深度最大的点 S' 之中有一个点 x 没有被选中,那么 S 中至少有两个不是 S' 的元素。对于任意 x\in S,存在 y\in S 使得 x, y 的 LCA 等于 S 的 LCA(这是对于任意有根树和任意 S 都成立的),且 (S\backslash S') \cup (S\cap S') = S,所以

所有工具已经准备完毕,只差最后一击了!设原来 m 个黑点的 LCA 为 d'

综上,原来的 m 个黑点现在依然可以是黑点。

\square
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pdi = pair<double, int>;
using pdd = pair<double, double>;
using ull = unsigned long long;
using LL = __int128_t;

#define ppc(x) __builtin_popcount(x)
#define clz(x) __builtin_clz(x)

bool Mbe;
// mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());
mt19937_64 rnd(1064);
int rd(int l, int r) {
  return rnd() % (r - l + 1) + l;
}

// ---------- templates above ----------

constexpr int N = 5e5 + 5;

int n, k;
ll ans = LONG_LONG_MAX;
vector<int> e[N];

int R, mx[N], sz[N], vis[N];
void findr(int id, int ff, int tot) {
  sz[id] = 1, mx[id] = 0;
  for(int it : e[id]) {
    if(it == ff || vis[it]) continue;
    findr(it, id, tot);
    sz[id] += sz[it];
    mx[id] = max(mx[id], sz[it]);
  }
  mx[id] = max(mx[id], tot - sz[id]);
  if(mx[id] < mx[R]) R = id;
}

int dep[N], bel[N];
struct dat {
  int dep, id;
  bool operator < (const dat &z) const {
    return dep < z.dep;
  }
};
vector<dat> arr;
void findd(int id, int ff, int dp, int anc) {
  bel[id] = anc;
  arr.push_back({dep[id] = dp, id});
  for(int it : e[id]) {
    if(it == ff) continue;
    findd(it, id, dp + 1, anc ? anc : it);
  }
}

int f[N], cnt[N];
void dfs(int id, int ff) {
  for(int it : e[id]) {
    if(it == ff) continue;
    dfs(it, id), f[id] += f[it];
  }
}
void divide(int id) {
  vis[id] = 1;
  arr.clear();
  findd(id, 0, 0, 0);
  sort(arr.begin(), arr.end());
  vector<int> to;
  ll sum = 1ll * (n - 1) * k, lst = 0;
  memset(cnt, 0, N << 2);
  memset(f, 0, N << 2);
  for(int _ = 1; _ <= k; _++) {
    while(!arr.empty()) {
      int tid = arr.back().id;
      if((cnt[bel[tid]] + 1) * 2 <= k) break;
      else to.push_back(bel[tid]);
      arr.pop_back();
    }
    if(arr.empty()) break;
    sum -= lst = 2 * arr.back().dep;
    f[arr.back().id] = 1;
    cnt[bel[arr.back().id]]++;
    arr.pop_back();
    if(_ == k) ans = min(ans, sum);
  }

  sort(to.begin(), to.end());
  to.resize(unique(to.begin(), to.end()) - to.begin());
  if(to.size() <= 1) {
    for(int it : to) {
      if(vis[it]) continue;
      findr(it, id, n);
      R = 0, findr(it, id, sz[it]);
      divide(R);
    }
    return;
  }

  sum += lst, dfs(id, 0); // 撤回到 k - 1 的方案. 这里可以不清空 lst 对应 tid 的 f, 想一想为什么?
  static int dis[N], mx = 0;
  memset(dis, -1, N << 2);
  queue<int> q;
  for(int i = 1; i <= n; i++) {
    if(f[i] >= k / 2) dis[i] = 0, q.push(i);
  }
  while(!q.empty()) {
    int t = q.front();
    q.pop();
    if(!f[t]) mx = dis[t];
    for(int it : e[t]) {
      if(dis[it] == -1) {
        dis[it] = dis[t] + 1;
        q.push(it);
      }
    }
  }
  ans = min(ans, sum - 2 * mx);
}

void solve() {
  cin >> n >> k;
  if(k == 1) {
    cout << n - 1 << "\n";
    return;
  }
  for(int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    e[u].push_back(v);
    e[v].push_back(u);
  }
  mx[0] = N, findr(1, 0, n);
  divide(R);
  cout << ans << endl;
}

bool Med;
int main() {
  fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  #ifdef ALEX_WEI
    FILE* IN = freopen("1.in", "r", stdin);
    FILE* OUT = freopen("1.out", "w", stdout);
  #endif
  int T = 1;
  while(T--) solve();
  fprintf(stderr, "%d ms\n", int(1e3 * clock() / CLOCKS_PER_SEC));
  return 0;
}

/*
g++ a.cpp -o a -std=c++14 -O2 -DALEX_WEI
*/