P6847 [CEOI2019] Magic Tree

· · 题解

提供一种别样的线段树合并写法。

\mathcal Magic

魔法树

给定一棵 n 个节点的树,有些节点上有果实,有值 d,w,分别表示果实成熟的日期和价值。

如果恰好在日期当天切掉该节点与根的连通,那么就能得到该价值。求能得到的价值的最大值。

转化一下题意,在树上选择一些点,使得所有被选择的点对 (u,v),若 v 在以 u 为根的子树中,那么有 d(v)\leq d(u)

考虑最朴素的树形 DP,设 f(u,i) 表示以 u 为根的子树,所有切断操作都在 \leq i 天内完成的最大价值。

那么对于 u 的一个儿子 v,它的贡献有:

  1. 不考虑加入 u 点的价值:f(u,i)=\sum f(v,i)
  2. 考虑加入 u 点的价值:f(u,j)=w(u)+\sum f(v,d(u)),其中 j\geq d(u)

因为两者是互斥关系,所以需要取个 \max

这样就得到了 O(nk) 的暴力算法,考虑使用数据结构优化它,不难想到线段树合并。

对于转移一可以直接合并,打上个区间加标价即可,对于二相当于是区间 [d(u),k] 对那个值取 \max

然后提供一个小 trick,不考虑使用普通线段树的区间赋值标记(比较难合并),而是记录下区间 \min,\max

如果一个区间的 \min=\max,那么这个区间所有值相等,这种情况下,

合并时直接被合并,更新时直接删掉左右儿子,修改时添加左右儿子,就相当于有了区间赋值的作用了。

时间复杂度 O(n\log n),若写个空间回收那么空间复杂度同样为 O(n\log n),可以参考下面的写法:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;

typedef long long LL;
const int N = 100010;
const int M = N * 40;

int n, m, k, tot, d[N], w[N], rt[N];
struct Tree{int l, r; LL mx, mn, add;} t[M];
vector<int> bac, G[N];

int read(){
    int x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
}

bool Same(int p){return (t[p].mx == t[p].mn);}

int New(){
    if(!bac.empty()){
        int p = bac.back();
        bac.pop_back();
        return p;
    }
    return ++ tot;
}

void Delet(int &x){
    if(!x) return;
    bac.push_back(x);
    t[x].l = t[x].r = 0;
    t[x].mx = t[x].mn = t[x].add = 0;
    x = 0;
}

void Plus(int p, LL v){
    if(!p) return;
    t[p].mx += v;
    t[p].mn += v;
    t[p].add += v;
}

void Push_Down(int p){
    if(!t[p].add) return;
    Plus(t[p].l, t[p].add);
    Plus(t[p].r, t[p].add);
    t[p].add = 0;
}

void Push_Up(int p){
    int l = t[p].l, r = t[p].r;
    t[p].mx = max(t[l].mx, t[r].mx);
    t[p].mn = min(t[l].mn, t[r].mn);
    if(t[p].mn == t[p].mx)
        Delet(t[p].l), Delet(t[p].r);
}

int Merge(int p, int q){
    if(!p || !q) return p + q;
    if(Same(q)){Plus(p, t[q].mx); return p;}
    if(Same(p)){Plus(q, t[p].mx); return q;}
    Push_Down(p), Push_Down(q);
    t[p].l = Merge(t[p].l, t[q].l);
    t[p].r = Merge(t[p].r, t[q].r);
    Push_Up(p);
    return p;
}

LL Query(int p, int l, int r, int k){
    if(Same(p)) return t[p].mx;
    int mid = (l + r) >> 1;
    Push_Down(p);
    if(k <= mid)
        return Query(t[p].l, l, mid, k);
    else
        return Query(t[p].r, mid + 1, r, k);
}

void Modify(int p, int l, int r, int L, int R, LL v){
    if(t[p].mn >= v) return;
    if(L <= l && r <= R && t[p].mx <= v){
        t[p].mx = t[p].mn = v; t[p].add = 0;
        Delet(t[p].l), Delet(t[p].r);
        return;
    }
    Push_Down(p);
    if(Same(p)){
        t[p].l = New(), t[p].r = New();
        int ls = t[p].l, rs = t[p].r;
        t[ls].mx = t[ls].mn = t[rs].mx = t[rs].mn = t[p].mx;
    }
    int mid = (l + r) >> 1;
    if(L <= mid)
        Modify(t[p].l, l, mid, L, R, v);
    if(R > mid)
        Modify(t[p].r, mid + 1, r, L, R, v);
    Push_Up(p);
}

void dfs(int u){
    rt[u] = New();
    for(int i = 0; i < (int) G[u].size(); i ++){
        int v = G[u][i];
        dfs(v);
        rt[u] = Merge(rt[u], rt[v]);
    }
    if(d[u]){
        LL val = Query(rt[u], 1, k, d[u]);
        Modify(rt[u], 1, k, d[u], k, val + w[u]);
    }
}

int main(){
    n = read(), m = read(), k = read();
    for(int v = 2; v <= n; v ++){
        int u = read();
        G[u].push_back(v);
    }
    for(int i = 1; i <= m; i ++){
        int p = read();
        d[p] = read(), w[p] = read();
    }
    dfs(1);
    printf("%lld\n", t[rt[1]].mx);
    return 0;
}