题解:P11097 [ROI 2022] 采购优化 (Day 1)

· · 题解

原题链接

题意简述

给出一棵树,每个节点 i 有代价 c_i,并且给出区间 [l_i,r_i] 要求该节点子树内的权值和在该区间内。

要决定每个点的权值,使得满足给出的条件,并且每个节点的权值乘代价之和最小,给出一种方案。

解题思路

宝宝题,但是我假了 114514 发。

注意到多选点一定不优,所以满足最低要求就好了。考虑对于每个节点,其子树内节点对当前节点的贡献。

首先满足每个点的最低要求,显然如果存在一个点 i,使得其子节点满足最低要求之后的和超过 r_i,则无解。

设所有子节点为根的子树内满足条件后的权值之和为 sum_i,则要使 sum_i 满足条件必须多选出几个点,显然贪心选一定是正确的。

对于每个节点维护一坨三元组 \{val,cnt,id\} 表示在不超出子树内限制的情况下,可以添加 cnt 个代价为 val 的点,位置在 id

将子节点合并到父节点可以使用启发式合并,加上维持有序,开销是两只老哥。考虑怎么维护当前节点使得无论怎么选都不会超出子树内限制。

显然可以维护三元组使所有三元组的 cnt 之和恰好为 r_i-\max(l_i,sum_i)(注意这里要把代价为 c_i 的三元组也加进去),从后往前删去权值大的点直到满足条件即可。

复杂度分析

本题复杂度构成可以分成两个部分:

因此本题的时间复杂度为 O(n\log^2n)

参考代码

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll N = 100005;
ll T , n , fa[N] , c[N] , L[N] , R[N] , sum[N] , ans[N] , f[N] , ssum[N];
struct point
{
    ll val , num , id;
    bool operator <(const point &b) const
    {
        if(val != b.val) return val < b.val;
        if(num != b.num) return num < b.num;
        return id < b.id;
    }
};
set <point> s[N];
vector<ll> g[N];
bool flag;

void dfs(ll u)
{
    sum[u] = f[u] = ssum[u] = ans[u] = 0;
    for(auto v : g[u])
    {
        dfs(v); sum[u] += sum[v]; f[u] += f[v]; ssum[u] += ssum[v];
        if(s[v].size() > s[u].size()) swap(s[u] , s[v]);
        while(!s[v].empty()) s[u].insert(*s[v].begin()) , s[v].erase(s[v].begin());
    }
    if(sum[u] > R[u]) return flag = 1 , void();
    while(sum[u] < L[u])
    {
        if(!s[u].empty() && s[u].begin()->val < c[u])
        {
            auto [val , num , id] = *s[u].begin(); s[u].erase(s[u].begin());
            ssum[u] -= num;
            if(sum[u] + num > L[u])
            {
                ll tmp = L[u] - sum[u];
                num -= tmp; ans[id] += tmp;
                f[u] += tmp * val; sum[u] = L[u];
                s[u].insert({val , num , id});
                ssum[u] += num;
            }
            else
            {
                sum[u] += num; ans[id] += num;
                f[u] += num * val;
            }
        }
        else
        {
            ll tmp = L[u] - sum[u]; sum[u] = L[u];
            ans[u] += tmp; f[u] += tmp * c[u];
        }
    }
    ll tot = R[u] - sum[u];
    while(!s[u].empty() && prev(s[u].end())->val > c[u])
        ssum[u] -= prev(s[u].end())->num,
        s[u].erase(prev(s[u].end()));
    while(ssum[u] > tot && !s[u].empty())
    {
        auto it = --s[u].end(); s[u].erase(it);
        auto [val , num , id] = *it;
        if(ssum[u] - num < tot)
        {
            num -= ssum[u] - tot;
            ssum[u] = tot;
            s[u].insert({val , num , id});
        }
        else ssum[u] -= num;
    }
    if(ssum[u] < tot) s[u].insert({c[u] , tot - ssum[u] , u}) , ssum[u] = tot;
}

void solve()
{
    cin >> n; flag = 0;
    for(ll i = 2 ; i <= n ; i++) cin >> fa[i] , g[fa[i]].emplace_back(i);
    for(ll i = 1 ; i <= n ; i++) cin >> c[i];
    for(ll i = 1 ; i <= n ; i++) cin >> L[i] >> R[i];
    dfs(1); cout << (flag ? -1 : f[1]) << '\n';
    if(!flag) for(ll i = 1 ; i <= n ; i++) cout << ans[i] << ' '; cout << '\n';
    for(ll i = 1 ; i <= n ; i++) s[i].clear() , g[i].clear();
}

signed main()
{
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> T; while(T--) solve();
    return 0;
}