P9132 [USACO23FEB] Watching Cowflix P
其实我最先注意到的是如果两个关键点距离
这样说起来是很简单,但写起来还是有不少细节。代码中的实现方法是这样的:我们先求出虚树,然后预处理出每个点及每条边被合并的时间,以及每个点被某个联通块完全包裹在内部的时间。这样我们可以求出每个时刻 DP 前的连通块会新增多少个点,多少条边。用 set 维护当前剩下的点的 DFS 序以减小 DP 的常数,当一个点被某个连通块完全包裹在内部的时候就可以把它删掉了。具体的细节参见代码。
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5, inf = 2e9;
int n, par[N], val[N], dep[N]; bool key[N], era[N];
set <int> e[N];
void remove(int u) {
for (auto v : e[u]) {
par[v] = u, val[v] = 1, dep[v] = dep[u] + 1;
e[v].erase(e[v].find(u)), remove(v);
}
if (!key[u]) {
bool ok = 1;
for (auto v : e[u]) ok &= era[v];
if (ok) era[u] = 1;
}
}
int dis[N], _dis[N], tim[N], _tim[N], c[N], d[N], f[N][2], ans[N];
vector <int> buk[N];
struct dat {
int i;
bool operator < (const dat &p) const {
return dep[i] != dep[p.i] ? dep[i] > dep[p.i] : i > p.i;
}
};
set <dat> cur;
void dfs(int u) {
dis[u] = _dis[u] = (key[u] ? 0 : inf);
for (int v : e[u]) {
dfs(v);
int L = dis[v] + val[v];
if (L < dis[u]) _dis[u] = dis[u], dis[u] = L;
else if (L < _dis[u]) _dis[u] = L;
}
}
void calc(int u, int L) {
_tim[u] = dis[u] + L;
for (int v : e[u]) {
int _L = (dis[v] + val[v] == dis[u]) ? _dis[u] : dis[u];
calc(v, min(_L, L) + val[v]);
}
}
signed main() {
ios :: sync_with_stdio(false);
cin.tie(nullptr);
string s;
cin >> n >> s;
int root = -1;
for (int i = 1; i <= n; i++) if (s[i - 1] == '1') key[i] = 1, root = i;
for (int i = 1, x, y; i < n; i++) {
cin >> x >> y;
e[x].insert(y), e[y].insert(x);
}
remove(root);
for (int i = 1; i <= n; i++) if (era[i]) e[par[i]].erase(e[par[i]].find(i))
for (int i = 1; i <= n; i++) {
if (!era[i] && !key[i] && e[i].size() == 1) { int u = i, v = -1;
for (auto z : e[u]) v = z;
assert(v > 0);
par[v] = par[u], val[v] += val[u], e[par[u]].erase(e[par[u]].find(u)), e[par[u]].insert(v);
era[u] = 1;
}
}
dfs(root), calc(root, 0);
for (int i = 1; i <= n; i++) if (!era[i]) {
cur.insert((dat){ i });
tim[i] = key[i] ? 0 : _tim[i];
int t = _tim[i];
for (auto v : e[i]) tim[i] = min(tim[i], _tim[v]), t = max(t, _tim[v]);
buk[t].push_back(i);
c[tim[i]]++;
if (dep[i] != 0) c[_tim[i]] += val[i] - 1, d[_tim[i]] += val[i];
}
int C = c[0], D = 0;
for (int k = 1; k <= n; k++) {
for (int i : buk[k]) {
auto it = cur.find((dat){ i });
assert(it != cur.end());
cur.erase(it);
}
C += c[k], D += d[k];
ans[k] = k * (C - D) + C;
for (auto z : cur) {
int v = z.i;
if (tim[v] <= k) f[v][0] = inf;
else f[v][1] += k + 1;
if (_tim[v] <= k) ans[k] += f[v][1];
else {
int u = par[v];
f[u][0] += min(f[v][0], f[v][1]);
f[u][1] += min(f[v][0], min(f[v][1], f[v][1] - k + (val[v] - 1)));
}
f[v][0] = f[v][1] = 0;
}
}
for (int i = 1; i <= n; i++) cout << ans[i] << "\n";
return 0;
}