P9132 [USACO23FEB] Watching Cowflix P

· · 题解

其实我最先注意到的是如果两个关键点距离 \leq k,那么他们可以直接被合并。这样合并完之后最多只有 \mathcal{O}(\frac{n}{k}) 个关键点。然后发现好像还没最优呢,还可能在关键点的虚树上的某个非关键点处发生合并,那我们就求出虚树,然后树形 DP 一下,就做完了。因为只有 \mathcal{O}(\frac{n}{k}) 个关键点,所以虚树上最多也只有 \mathcal{O}(\frac{n}{k}) 个点,根据调和级数总时间复杂度 \mathcal{O}(n \log n)

这样说起来是很简单,但写起来还是有不少细节。代码中的实现方法是这样的:我们先求出虚树,然后预处理出每个点及每条边被合并的时间,以及每个点被某个联通块完全包裹在内部的时间。这样我们可以求出每个时刻 DP 前的连通块会新增多少个点,多少条边。用 set 维护当前剩下的点的 DFS 序以减小 DP 的常数,当一个点被某个连通块完全包裹在内部的时候就可以把它删掉了。具体的细节参见代码。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5, inf = 2e9;
int n, par[N], val[N], dep[N]; bool key[N], era[N];
set <int> e[N];
void remove(int u) {
    for (auto v : e[u]) {
        par[v] = u, val[v] = 1, dep[v] = dep[u] + 1;
        e[v].erase(e[v].find(u)), remove(v);
    }
    if (!key[u]) {
        bool ok = 1;
        for (auto v : e[u]) ok &= era[v];
        if (ok) era[u] = 1;
    }
}
int dis[N], _dis[N], tim[N], _tim[N], c[N], d[N], f[N][2], ans[N];
vector <int> buk[N];
struct dat {
    int i;
    bool operator < (const dat &p) const {
        return dep[i] != dep[p.i] ? dep[i] > dep[p.i] : i > p.i;
    }
};
set <dat> cur; 
void dfs(int u) {
    dis[u] = _dis[u] = (key[u] ? 0 : inf);
    for (int v : e[u]) {
        dfs(v);
        int L = dis[v] + val[v];
        if (L < dis[u]) _dis[u] = dis[u], dis[u] = L;
        else if (L < _dis[u]) _dis[u] = L;
    }
}
void calc(int u, int L) {
    _tim[u] = dis[u] + L;
    for (int v : e[u]) {
        int _L = (dis[v] + val[v] == dis[u]) ? _dis[u] : dis[u];
        calc(v, min(_L, L) + val[v]);
    }
}
signed main() {  
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    string s;
    cin >> n >> s;
    int root = -1;
    for (int i = 1; i <= n; i++) if (s[i - 1] == '1') key[i] = 1, root = i;
    for (int i = 1, x, y; i < n; i++) {
        cin >> x >> y;
        e[x].insert(y), e[y].insert(x);
    }
    remove(root);
    for (int i = 1; i <= n; i++) if (era[i]) e[par[i]].erase(e[par[i]].find(i))
    for (int i = 1; i <= n; i++) {
        if (!era[i] && !key[i] && e[i].size() == 1) { int u = i, v = -1; 
            for (auto z : e[u]) v = z;
            assert(v > 0);
            par[v] = par[u], val[v] += val[u], e[par[u]].erase(e[par[u]].find(u)), e[par[u]].insert(v);
            era[u] = 1;
        }
    }
    dfs(root), calc(root, 0);
    for (int i = 1; i <= n; i++) if (!era[i]) {
        cur.insert((dat){ i });
        tim[i] = key[i] ? 0 : _tim[i];
        int t = _tim[i];
        for (auto v : e[i]) tim[i] = min(tim[i], _tim[v]), t = max(t, _tim[v]);
        buk[t].push_back(i);
        c[tim[i]]++;
        if (dep[i] != 0) c[_tim[i]] += val[i] - 1, d[_tim[i]] += val[i];
    } 
    int C = c[0], D = 0;
    for (int k = 1; k <= n; k++) {
        for (int i : buk[k]) {
            auto it = cur.find((dat){ i });
            assert(it != cur.end());
            cur.erase(it);
        }
        C += c[k], D += d[k];
        ans[k] = k * (C - D) + C;
        for (auto z : cur) {
            int v = z.i;
            if (tim[v] <= k) f[v][0] = inf;
            else f[v][1] += k + 1;
            if (_tim[v] <= k) ans[k] += f[v][1];
            else {
                int u = par[v];
                f[u][0] += min(f[v][0], f[v][1]);
                f[u][1] += min(f[v][0], min(f[v][1], f[v][1] - k + (val[v] - 1)));
            }
            f[v][0] = f[v][1] = 0;
        }
    }
    for (int i = 1; i <= n; i++) cout << ans[i] << "\n";
    return 0; 
}