树上背包时间复杂度证明

· · 算法·理论

P2014 [CTSC1997] 选课

本文证明树上背包的时间复杂度是 \mathcal O(nm) 的。换言之,上题可以加强至:n \le 3 \times 10^5m \le 300

首先给出代码实现:

#include <bits/stdc++.h>
using namespace std;

#define dbg(...) cerr << "[" << #__VA_ARGS__ << "] = ", debug_out(__VA_ARGS__)
template <typename T> void debug_out(T t) { cerr << t << endl; }
template <typename T, typename... Ts> void debug_out(T t, Ts... ts) {
    cerr << t << ", ";
    debug_out(ts...);
}

template <class F> struct y_combinator {
    F f;
    template <class... Args> decltype(auto) operator()(Args &&...args) {
        return f(*this, std::forward<Args>(args)...);
    }
};
template <class F> auto make_y_combinator(F f) { return y_combinator<F>{f}; }

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    vector<int> a(n + 1);
    vector<vector<int>> e(n + 1);
    for (int i = 1, f; i <= n; i++) {
        cin >> f >> a[i];
        e[f].emplace_back(i);
    }
    ++m;
    vector<vector<int>> dp(n + 1, vector<int>(m + 1));
    vector<int> sz(n + 1);
    auto dfs = make_y_combinator([&](auto self, int u) -> void {
        dp[u][1] = a[u];
        sz[u] = 1;
        for (int v : e[u]) {
            self(v);
            for (int i = min(sz[u], m); i >= 1; i--) {
                for (int j = min(sz[v], m - i); j >= 1; j--) {
                    dp[u][i + j] = max(dp[u][i + j], dp[u][i] + dp[v][j]);
                }
            }
            sz[u] += sz[v];
        }
    });
    dfs(0);
    cout << dp[0][m] << endl;
    return 0;
}

让我们重点关注:

for (int i = min(sz[u], m); i >= 1; i--) {
    for (int j = min(sz[v], m - i); j >= 1; j--) {
        dp[u][i + j] = max(dp[u][i + j], dp[u][i] + dp[v][j]);
    }
}

这也就是说,对于每条边 u \to v,该部分转移的时间复杂度为

\mathcal O(\min(m, pre_v) \times \min(m, siz_v))

其中:

那么总时间复杂度即为

\mathcal O\left(\sum_{u \to v} \min(m, pre_v) \times \min(m, siz_v)\right)

接下来存在两个 自然的观察

  1. 考虑 \min(m, pre_v) \times \min(m, siz_v) \le m^2,由于边数 O(n),因此总时间复杂度不超过 \mathcal O(nm^2)

  2. 考虑 \min(m, pre_v) \times \min(m, siz_v) \le pre_v \cdot siz_v,这可以理解成在 u \to v 处对所有满足如下条件的 (x, y) 进行计数:

    注意到任意 (x,y) 仅会在它们的 \text{lca} 处被计数一次,于是有 \sum_{u\to v} pre_v \cdot siz_v \le n^2,因此总时间复杂度不超过 \mathcal O(n^2)

然未尽其析。

下证时间复杂度为 \mathcal O(nm)

我们将子树大小 \le m 的点称之为蓝点,子树大小 > m 的点称之为红点。

同时把边分为三类:

例如 n = 20m = 3 时,考虑下图:

接下来我们分四类讨论:

考虑所有蓝边,我们可以得到若干 极大蓝子树(图中有蓝点 6, 7, 9, 10, 11, 18, 19, 20 的子树)。根据上述 自然的观察 2,大小为 s 的子树内部对时间复杂度的贡献不超过 \mathcal O(s^2)。假设这些 极大蓝子树 的大小为 s_i,则:

这可以推出

\sum s_i^2 \le \sum s_i \cdot m \le nm

于是所有蓝边对时间复杂度的贡献不超过 \mathcal O(nm)

考虑所有黄边 u \to v,此时所有 v 的子树即为 极大蓝子树,它们是互斥的,有 \sum siz_v \le n,因此

\sum \min(m, pre_v) \times \min(m, siz_v) \le \sum m \cdot siz_v \le mn

于是所有黄边对时间复杂度的贡献不超过 \mathcal O(nm)

注意到仅红点和红边也构成一棵树,我们称之为红树。

考虑 红树的叶子节点,其个数为 \mathcal O\left(\dfrac nm\right),这是因为在原树中这些点的子树大小均 > m,且这些子树互斥。

仅考虑满足如下条件的红边 u \to v(让我们称之为 深红边):

(图中有 深红边 1\to 21 \to 32 \to 42 \to 5深红点 1,2

可以将每个 深红点 理解为,对至少两个 红树的叶子节点 进行合并。因此 深红点 的个数 < 红树的叶子节点 的个数,从而 深红边 的数量也为 \mathcal O\left(\dfrac nm\right)

根据上述 自然的观察 1,每条边对时间复杂度的贡献不超过 \mathcal O(m^2),因此所有 深红边 对时间复杂度的贡献不超过 \mathcal O\left(\dfrac nm \times m^2\right) = \mathcal O(nm)

最后考虑 浅红边 u \to v(即不是 深红边 的红边,图中有 3 \to 88 \to 13),此时所有 v 左边的子树 是互斥的,有 \sum pre_v \le n,因此

\sum \min(m, pre_v) \times \min(m, siz_v) \le \sum pre_v \cdot m \le nm

于是所有 浅红边 对时间复杂度的贡献不超过 \mathcal O(nm)

综上,树上背包时间复杂度不超过 \mathcal O(nm),这显然没法更低了,因此就是 \mathcal O(nm)

然犹未尽析。

有没有更直观一点的解释呢?

回到式子:

\min(m, pre_v) \times \min(m, siz_v)

考虑 dfs 序 dfn,仿照 自然的观察 2,这可以理解成在 u \to v 处对所有满足如下条件的 (x, y) 进行计数:

注意到任意 (x,y) 仅会在它们的 \text{lca} 处被最多计数一次,且只有当 dfn_y - dfn_x < 2m 才会被计数。这样的 (x,y)\mathcal O(nm) 对,因此总时间复杂度为 \mathcal O(nm)