题解:P12136 [蓝桥杯 2025 省 B] 生产车间

· · 题解

能意识到 G 有多难就能意识到 G 有多简单。

本文将介绍 P12136 [蓝桥杯 2025 省 B] 生产车间 与子集和问题之间的关联,并给出基于此得出的做法,以及可能的时间优化方式。

题意简述

给定一棵有根树,忽略若干个节点使得对于剩下的非叶节点 u,它所包含的剩下的叶节点权值之和都 \le w_u;并对于根节点 1,最大化它所包含的剩下的叶节点权值之和。

题目分析

树上问题,自然想到要用树形 dp 来做。然后注意到数据范围 n \le 1,000w \le 1,000,应该能意识到这题只要求平方时间的做法。

这意味着,或许,这个题很难,使得更低复杂度的做法可能是不存在的。但到底有多难?我们可以拿另一个已经有明确难度的题来比较。

与子集和问题的联系

考虑一种比较极端的数据:除了 1 之外的所有节点全都是叶节点,编号构成的集合为 L = \{2, 3, \dots, n\}。此时问题转化成找出一个子集 T \sube L 使得 \sum_{v \in T} w_v \le w_1\sum_{v \in T} w_v 最大。

如果你解决了该问题,那么通过判定 \sum_{v \in T} w_v 是否等于 w_1,你就直接判定了 L 是否存在一子集,它的权值和为 w_1。而这就是著名的子集和问题,在值域没有限制的时候是 NP-Complete 的。

综上,该问题是严格不弱于子集和问题的,如果 w 没有限制,n \le 1000 这样大的范围是难以短时间做出来的。

小值域子集和问题

现在应该就能看出 w \le 1,000 的作用了。子集和问题虽然是 NP-Complete 的,但是其存在 O(n\sum w) 的伪多项式时间的做法。对于本题而言,w 是限制在每一个节点上的,从而对每一个节点,它的子集和的范围都不过 O(w) 而已。

到了这里,应该就能够得出 O(nw^2) 的做法了:

S_v = (S_{r_1} + S_{r_2} + \dots + S_{r_{d^+(v)}}) \cap [0, w_v]

其中 S + T = \{s + t \mid \forall s \in S, \forall t \in T\}r_iv 的第 i 个子节点。

最后对于 S_1,输出 \max S_1 即可。

可能的优化

尽管 O(nw^2) 很有可能跑不满但仍然是可达到的,一旦达到了就可能超时。由于瓶颈在子集和问题上且不可能比子集和问题简单,唯一的方法就是优化解决子集和问题的算法。

可以使用 C++ 的 std::bitset 来优化常数,此时单次合并子集和可以优化到 O\bigg(\dfrac{w^2}{\omega}\bigg),其中 \omegastd::bitset 所用的字长,一般为 32 或 64。总时间复杂度 O\bigg(\dfrac{nw^2}{\omega}\bigg)

另一种更狠的优化方式超出了本题的难度范围,本人赛场上懒得想了就直接写了这个:使用 FFT 优化单次合并子集和,可以做到 O(w \log w)。对于有兴趣的,这里讲解一下:

给定子集和 S,T 生成 S + T,你要优化的本质上是生成 s + t 的过程。对 S 构造生成函数 F_S(z) = \sum_{s \in S} z^s,此时有:

F_S(z)F_T(z) = \sum_{s \in S} z^s \sum_{t \in T} z^t = \sum_{s \in S, t \in T} z^{s + t}

从而 F_S(z)F_T(z)z^k 的系数表示有多少对 s, t 满足 s + t = k,若系数 \ge 1 则表示存在这样的 s, t

而计算 F_S(z)F_T(z) 显然就是多项式乘法,使用 FFT 就可以做到 O(w \log w) 的时间复杂度。此时本题的时间复杂度为 O(nw \log w),时间绰绰有余。

参考代码

#include <algorithm>
#include <array>
#include <iostream>
#include <vector>

using std::cin;
using std::cout;
using std::vector;

using u64 = unsigned long long;

constexpr u64 P = 998'244'353;

template<class Grp, class Exp, class Ty>
Ty pow(Grp base, Exp exp, Ty id) {
    for(; exp; exp >>= 1, base *= base)
        if(exp & 1)
            id *= base;
    return id;
}

template<class Grp, class Exp, class Ty>
Ty& powass(Grp base, Exp exp, Ty &id) {
    for(; exp; exp >>= 1, base *= base)
        if(exp & 1)
            id *= base;
    return id;
}

struct Fp {

    u64 x;

    constexpr Fp(u64 val = 0) : x(val) {}

    Fp& operator+=(const Fp &other) {
        x += other.x;
        if(x >= P) x -= P;
        return *this;
    }
    Fp& operator-=(const Fp &other) {
        if(x < other.x) x += P;
        x -= other.x;
        return *this;
    }
    Fp& operator*=(const Fp &other) {
        x = x * other.x % P;
        return *this;
    }
    Fp& operator/=(const Fp &other) {
        return powass(other, P - 2, *this);
    }
    friend Fp operator+(Fp lhs, const Fp &rhs) { return lhs += rhs; }
    friend Fp operator-(Fp lhs, const Fp &rhs) { return lhs -= rhs; }
    friend Fp operator*(Fp lhs, const Fp &rhs) { return lhs *= rhs; }
    friend Fp operator/(Fp lhs, const Fp &rhs) { return lhs /= rhs; }
};

Fp prim(3);
Fp coprim = Fp(1) / prim;

vector<u64> graph[1001];
vector<u64> subset[1001];
bool visited[1001] = {};
u64 w[1001] = {};

template<class RanIt>
void FFT(RanIt arr, bool type) {
    for(int logseg = 0; logseg < 11; ++logseg) {
        size_t seg = 1ull << logseg;
        Fp omega = pow(type ? coprim : prim, (P - 1) >> (logseg + 1), Fp(1));
        for(RanIt lhs = arr; lhs != arr + 2048; lhs += (seg + seg)) {
            RanIt rhs = lhs + seg;
            Fp phi = Fp(1);
            for(size_t off = 0; off < seg; ++off) {
                Fp g = lhs[off];
                Fp hw = rhs[off] * phi;
                lhs[off] = g + hw;
                rhs[off] = g - hw;
                phi *= omega;
            }
        }
    }
    if(type) {
        Fp div = Fp(1) / Fp(2048);
        for(size_t i = 0; i < 2048; ++i)
            arr[i] *= div;
    }
}

vector<u64> subsetmerge(u64 u, const vector<u64> &lset, const vector<u64> &rset) {
    if(lset.size() * rset.size() < 2048) {
        bool contains[2001] = {};
        for(u64 x : lset) {
            for(u64 y : rset) {
                contains[x + y] = true;
            }
        }
        vector<u64> res;
        for(u64 z = 0; z <= w[u]; ++z)
            if(contains[z])
                res.push_back(z);
        return res;
    }
    std::array<Fp, 2048> lff{}, rff{};
    for(u64 x : lset) lff[x] = Fp(1);
    for(u64 x : rset) rff[x] = Fp(1);
    FFT(lff.begin(), false);
    FFT(rff.begin(), false);
    for(size_t i = 0; i < 2048; ++i)
        lff[i] *= rff[i];
    FFT(lff.begin(), true);
    vector<u64> res;
    for(u64 z = 0; z <= w[u]; ++z)
        if(lff[z].x)
            res.push_back(z);
    return res;
}

void dfs(u64 u) {
    visited[u] = true;
    bool isleaf = true;
    for(u64 v : graph[u]) {
        if(!visited[v]) {
            isleaf = false;
            dfs(v);
            subset[u] = subsetmerge(u, subset[u], subset[v]);
        }
    }
    if(isleaf) {
        subset[u].push_back(w[u]);
    }
}

int main() {

    size_t n;
    cin >> n; 
    for(u64 u = 1; u <= n; ++u) {
        cin >> w[u];
        subset[u].push_back(0);
    }
    for(u64 i = 1; i < n; ++i) {
        u64 u, v;
        cin >> u >> v;
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
    dfs(1);
    cout << subset[1].back();
    return 0;
}