题解:P12481 [集训队互测 2024] 运筹帷幄/ ARIS0_0 - 38

· · 题解

这是一道树形 DP,但难点不在 DP,而在维护转移与换根。核心数据结构是数组和指针,维护思路相当巧妙。

1. 简单贪心

假设当前处于子树 u,我们需要通过移动子树内的棋子,使得子树内棋子到 u 的最小距离和最小。假设 u 所有儿子的子树已经按照最优方案调整好,现在可以调整 a_u-b_u 个棋子到 u。显然,将当前子树内离 u 距离最远的 a_u-b_u 个棋子挪到 u 更优。具体实现可以用树形 DP。

2. 朴素 DP

我们设 f_{u,i},g_{u,i} 分别表示经过最优调整后,节点 u 的子树中,距离大于 i 的棋子个数与距离总和。

易得转移方程:

f_{u,i}\gets f_{u,i}+f_{v,i - 1}\\ g_{u,i}\gets g_{u,i-1}+f_{v,i-1}

合并完所有儿子的贡献后,暴力枚举到一个位置 k,满足 f_{u,k}\ge a_u-b_u,所有距离大于 k 与部分距离等于 k 的棋子都将移动到 u,暴力更新即可。

随后对每个根都跑一遍 DP,单次 DP 的时间复杂度为 O(n^2),总时间复杂度为 O(n^3),显然不能接受。

3. 初步优化

朴素 DP 很明显可以用长剖优化,这样可以将空间复杂度优化至 O(n)。对于长儿子贡献的继承,可以维护偏移指针,O(1) 转移长儿子,转移方程不变。对于最优棋子调整,由于 f 具有单调性,因此可以通过二分来定位 k,随后维护减法标记做全局减法并截断删除部分。

具体来说,对于每个长链,我们需要维护 DP 数组 f,g,以及指针 l,r 与减法标记 tag

对于叶子节点,我们令其所在的长链的 l=r=len,其中 len 为长链长度,到子树的根的距离 i 就需要用 i+l 来表示了。对于长儿子的继承,可先令 l 左移,随后按方程转移即可,短儿子则可以正常转移。

至于最优棋子调整,先二分找到 k,然后令 r\gets k,更新减法标记为删除的棋子总数,这样 f,g 数组的值都不会被修改。到子树的根距离为 if 的实际值就为 f_{i+l}-tag,所有棋子到根的最小距离和就为 g_{l}-g_{r+1}-(r-l+1)\times tag

这样我们就把单次 DP 优化到了 O(n\log n),不过这玩意不好换根啊,单次消耗 O(n) 暴力重构节点信息的话,换根的复杂度就飙到 O(n^2\log n) 了,这和对每个根都跑一遍没区别,需要进一步优化。

4. 最终优化

长剖在换根下不好维护,于是我们把目光转向重剖,重剖显然也可以优化朴素 DP,且时间复杂度依旧是 O(n\log n)。至于空间,对于每条重链,其要维护的 DP 数组长度需要开到链顶子树高度的大小,因为长链与重链不一定一致,轻儿子是数组长度要是比重儿子大那就爆掉了。可以用数学归纳法证明,这么开数组的总长度依旧是 O(n) 的。

相较于长剖,重剖具有更多优秀的性质,可以帮助我们解决这个问题。

我们需要维护三个信息,分别是重链信息、近棋子集合和远棋子集合。

当我们完成单根 DP 时,我们就求出了所有重链信息。然后考虑换根,我们先考虑怎么求答案,假设目前我们有近棋子集合、远棋子集合与重儿子的信息,这些信息恰好覆盖了整棵树,现在考虑怎么移动棋子最优。首先要优先删远棋子集合,即移动远棋子集合中的棋子。若删完远棋子集合,再考虑重儿子和近棋子集合,若未删完,则删除其一部分棋子,做一个修改即可,删除方式均为二分。但这三者的距离范围可能有交集,需要揉一块二分,实现起来还是比较复杂的。

接下来考虑怎么手搓一个数据结构维护这三个玩意,让我们来构思一下这个神秘的结构:首先,它可以动态分配内存,根据其要维护的距离范围开数组;其次,它拥有两个指针来维护有效数组范围,且可以高效访问数组内部元素;最后,它需要是一个可撤销数据结构。

于是我们自然就想到了动态数组与指针。我们开一个结构体,分别维护动态数组 f,g、指针 l,r、减法标记 tag 以及一个记录历史状态的回滚栈。它需要实现以下功能:

  1. 求出其距离范围内的最小距离和
  2. 求出到根距离为 x 的棋子个数
  3. 求出有效数组长度
  4. 回滚至某一时刻的历史状态
  5. 拓展一个新节点
  6. 按贪心思路删除 x 个棋子
  7. 合并、消除相邻信息

虽然要实现的东西有点多,但是代码还是比较好写的:

:::success[核心数据结构]{open}

struct DS{ 
    ll tag = 0 , l , r;  // 删除标记与指针
    vector<ll> f , g; // DP 数组,分别记录棋子数目,距离总和
    vector<pair<ll* , ll>>sta; // 回滚栈
    DS() {}
    DS(ll tag , int l , int r , int n) {
        this->tag = tag , this->l = l , this->r = r , this->f.resize(n + 5) , this->g.resize(n + 5);
    } 
    ll sum() { // 这是移动棋子后的最小距离和,用于计算答案
        return g[l + 1] - g[r] - (r - l - 1) * tag;
    }
    ll qry(int x) { // 返回到根距离为 x 的棋子个数
        if(l + x < r) return f[l + x] - tag;
        else return 0;
    }
    int size() { // 有效数组长度
        return r - l + 1;
    }
    void recall(int x) { // 回滚至时刻 x 的历史状态
        while(sta.size() > x) {
            (*sta.back().first) = sta.back().second;
            sta.pop_back();
        }
    }
    void insert(ll x) { // 插入一个元素,并偏移指针 l
        sta.push_back({&l , l --}); // 记录到回滚栈中,方便撤销操作
        sta.push_back({&f[l] , f[l]});
        sta.push_back({&g[l] , g[l]});
        f[l] = f[l + 1] + x , g[l] = g[l + 1] + f[l]; // 转移 
    }
    void erase(ll x) { // 删除 x 个棋子,并偏移指针 r,即移动棋子到当前根
        if(size() <= 1 || x == 0) return;
        sta.push_back({&tag , tag}) , tag += x; // 更新删除的棋子数目
        ll it = lower_bound(f.begin() + l , f.begin() + r , tag , greater<ll>()) - f.begin(); // 二分
        if (r != it) sta.push_back({&r , r}) , r = it;
    }
    void merge(DS o , ll t) { // 合并/消除相邻信息
        if (l + o.size() > r) sta.push_back({&r , r});
        for (int i = l + o.size() - 1 ; i >= l; i --) {
            sta.push_back({&g[i] , g[i]});
            sta.push_back({&f[i] , f[i]});     
            if(i >= r) f[i] = tag; // 超出部分的实际值为 0,设为 tag
            f[i] += o.qry(i - l) * t;   
            g[i] = g[i + 1] + f[i];
        }
        r = max(r , l + o.size() - 1); // 更新数组长度
    }
} F(0 , N , N + N , N + N) , dp[N]; // 近棋子集合、重链信息

::: 该结构可以维护近棋子集合与重链信息,但远棋子集合不需要这么复杂的结构,考虑到其设计需要完整的换根思路,我们先讲完换根部分再设计它。

所以现在我们可以思考怎样具体实现换根了!

先求以当前节点为根的答案:首先我们调整重链为仅处理完重儿子的状态,直接回滚就可以了,随后,将轻儿子合并到近棋子集合,两者距离差为 1,属于相邻信息,可以直接合并。处理完这些后,将重链插入远棋子集合,随后根据当前节点上的棋子数,转移近棋子集合拓展距离为 0 的状态。此时就只需按之前的思路移动棋子,先删远棋子集合,再删近棋子集合,然后计算这两个集合所有元素保存的距离和就可以了。

注意完成当前根的求解后要撤销棋子的移动,恢复远近棋子集合到原来的状态。

接下来考虑给轻儿子换根:首先将轻儿子从近棋子集合中移除,随后拆重链,将重儿子的信息划分成两个部分,距离在轻儿子子树高度范围内的插入近棋子集合,超出部分则插入远棋子集合,随后转移近棋子集合拓展距离为 0 的状态,重新按之前的思路移动棋子。这样就处理完毕了,递归轻儿子,然后恢复到原来状态即可。

重儿子则没有那么麻烦,拓展近棋子集合、移动棋子后直接递归就可以了,注意最后要将远近棋子集合恢复到最开始的状态。

具体思路就是这样,但是你可以发现,近棋子集合只需要维护一个结构体就可以了,但远棋子集合不一样,它里面的元素是若干个从重链上拆分出来的片段,且它们表示的最近距离到当前根的距离是按照插入时间升序排列的,因此它是一个线性数据结构,需要实现以下功能:

  1. 求出片段中所有棋子到其截取位置的最小距离和
  2. 求出这片段中的棋子个数
  3. 求出到截取位置距离为 x 的棋子个数
  4. 复制某个重链信息的一段前缀
  5. 按贪心思路结合近棋子集合删除若干个棋子
  6. 先进先出的结构,可以遍历内部元素

:::success[辅助数据结构]{open}

struct node{ 
    ll tag , l , r;
    vector<ll> *f , *g; // 指针数组,直接指向复制目标的地址,省空间
    node() {}
    node(ll tag , int l , int r , vector<ll> *f , vector<ll> *g) {
        this->tag = tag , this->l = l , this->r = r , this->f = f , this->g = g;
    } 
    ll sum() { // 计算这一部分的贡献
        return (*g)[l] - (*g)[r + 1] - (r - l + 1) * tag;
    }
    ll cnt(){ // 返回这一部分的棋子个数
        if(l <= r) return (*f)[l] - tag;
        else return 0;
    }
    ll qry(int x) { // 返回到截取位置距离为 x 的棋子个数
        return (*f)[l + x] - tag;
    }
    ll erase(ll x); // 删除不能自己独立实现
};

:::

6 个功能就是队列啊,但是队列遍历元素太复杂了,不如链表:

list<pair<int , node>>S; // 远棋子集合

其中 pair 的第一个元素记录片段到当前根的距离,第二个元素记录片段信息。

在换根中涉及的拓展近棋子集合、移动棋子都需要远近棋子集合联合行动:

:::success[联合行动]{open}

ll node::erase(ll x) { // 删除远近棋子集合中的 x 个棋子
    ll ls = 0 , rs = r - l , res = r;
    while(ls <= rs) {
        ll mid = (ls + rs) >> 1; // 远集合中距离为 mid 的棋子个数与近棋子集合中距离为 mid+片段距离 的棋子个数,两者可能相交
        if (qry(mid) + F.qry(mid + S.back().first) >= x) res = l + mid , ls = mid + 1;
        else rs = mid - 1;
    }
    r = res;
    ll v = min((*f)[r] - tag , x - F.qry(r - l + 1 + S.back().first)); // 更新删除棋子的个数
    tag += v;   
    return v;
}
void insert(ll x) { // 插入一个元素,并更新远棋子集合的实际距离
    F.insert(x);
    for (auto &v : S) v.first ++;
}
void erase(ll x) { // 删除 x 个元素,优先删远棋子集合
    while (!S.empty() && F.qry(S.back().first) + S.back().second.cnt() <= x) x -= S.back().second.cnt() , S.pop_back();
    if (!S.empty() && F.qry(S.back().first + S.back().second.r - S.back().second.l + 1) < x) x -= S.back().second.erase(x);
    F.erase(x);   
}

::: 由重链剖分的性质可以证明,远棋子集合的大小是 O(\log n) 的,因此可以暴力更新与遍历。

总时间复杂度均摊下来是 O(n\log n),空间复杂度 O(n)

至此我们就完成这道题了!

5. 参考代码

可能讲的不是很详细,所以为代码加上了丰富的注释!

#include<bits/stdc++.h>
#define cin_fast ios::sync_with_stdio(false) , cin.tie(0) , cout.tie(0)
#define fi first
#define se second
//#define int long long 
#define in(a) a = read()
#define rep(i , a , b) for(int i = a ; i <= b ; i ++)
using namespace std;
typedef long long ll;
const int N = 5e5 + 5 , mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const long long INF = 0x3f3f3f3f3f3f3f3f; 
inline int read() {
    int x = 0;
    char ch = getchar();
    bool S = 0;
    while('9' < ch || ch < '0') S |= ch == '-' , ch = getchar();
    while('0' <= ch && ch <= '9') x = (x << 3) + (x << 1) + ch - '0' , ch = getchar();
    return S ? -x : x;
}
int n , a[N] , b[N];
struct Edge{
    int n , v;
}edge[N << 1];
int head[N] , eid;
void eadd(int u , int v) {
    edge[++ eid].n = head[u] , edge[head[u] = eid].v = v;
}
struct node{ 
    ll tag , l , r;
    vector<ll> *f , *g; // 指针数组,直接指向复制目标的地址,省空间
    node() {}
    node(ll tag , int l , int r , vector<ll> *f , vector<ll> *g) {
        this->tag = tag , this->l = l , this->r = r , this->f = f , this->g = g;
    } 
    ll sum() { // 计算这一部分的贡献
        return (*g)[l] - (*g)[r + 1] - (r - l + 1) * tag;
    }
    ll cnt(){ // 返回这一部分的棋子个数
        if(l <= r) return (*f)[l] - tag;
        else return 0;
    }
    ll qry(int x) { // 返回到截取位置距离为 x 的棋子个数
        return (*f)[l + x] - tag;
    }
    ll erase(ll x); // 删除操作需要和爱丽丝合作,不能自己实现
};
struct DS{ 
    ll tag = 0 , l , r;  // 删除标记与指针
    vector<ll> f , g; // DP 数组,分别记录棋子数目,距离总和
    vector<pair<ll* , ll>>sta; // 回滚栈
    DS() {}
    DS(ll tag , int l , int r , int n) {
        this->tag = tag , this->l = l , this->r = r , this->f.resize(n + 5) , this->g.resize(n + 5);
    } 
    ll sum() { // 这是移动棋子后的最小距离和,用于计算答案
        return g[l + 1] - g[r] - (r - l - 1) * tag;
    }
    ll qry(int x) { // 返回到根距离为 x 的棋子个数
        if(l + x < r) return f[l + x] - tag;
        else return 0;
    }
    int size() { // 有效数组长度
        return r - l + 1;
    }
    void recall(int x) { // 回滚至时刻 x 的历史状态
        while(sta.size() > x) {
            (*sta.back().first) = sta.back().second;
            sta.pop_back();
        }
    }
    void insert(ll x) { // 插入一个元素,并偏移指针 l
        sta.push_back({&l , l --}); // 记录到回滚栈中,方便撤销操作
        sta.push_back({&f[l] , f[l]});
        sta.push_back({&g[l] , g[l]});
        f[l] = f[l + 1] + x , g[l] = g[l + 1] + f[l]; // 转移 
    }
    void erase(ll x) { // 删除 x 个棋子,并偏移指针 r,即移动棋子到当前根
        if(size() <= 1 || x == 0) return;
        sta.push_back({&tag , tag}) , tag += x; // 更新删除的棋子数目
        ll it = lower_bound(f.begin() + l , f.begin() + r , tag , greater<ll>()) - f.begin(); // 二分
        if (r != it) sta.push_back({&r , r}) , r = it;
    }
    void merge(DS o , ll t) { // 合并/消除相邻信息
        if (l + o.size() > r) sta.push_back({&r , r});
        for (int i = l + o.size() - 1 ; i >= l; i --) {
            sta.push_back({&g[i] , g[i]});
            sta.push_back({&f[i] , f[i]});     
            if(i >= r) f[i] = tag; // 超出部分的实际值为 0,设为 tag
            f[i] += o.qry(i - l) * t;   
            g[i] = g[i + 1] + f[i];
        }
        r = max(r , l + o.size() - 1); // 更新数组长度
    }
} F(0 , N , N + N , N + N) , dp[N]; // 近棋子集合、重链信息
list<pair<int , node>>S; // 远棋子集合
ll node::erase(ll x) { // 删除远近棋子集合中的 x 个棋子
    ll ls = 0 , rs = r - l , res = r;
    while(ls <= rs) {
        ll mid = (ls + rs) >> 1; // 远集合中距离为 mid 的棋子个数与近棋子集合中距离为 mid+片段距离 的棋子个数,两者可能相交
        if (qry(mid) + F.qry(mid + S.back().first) >= x) res = l + mid , ls = mid + 1;
        else rs = mid - 1;
    }
    r = res;
    ll v = min((*f)[r] - tag , x - F.qry(r - l + 1 + S.back().first)); // 更新删除棋子的个数
    tag += v;   
    return v;
}
void insert(ll x) { // 插入一个元素,并更新远棋子集合的实际距离
    F.insert(x);
    for (auto &v : S) v.first ++;
}
void erase(ll x) { // 删除 x 个元素,优先删远棋子集合
    while (!S.empty() && F.qry(S.back().first) + S.back().second.cnt() <= x) x -= S.back().second.cnt() , S.pop_back();
    if (!S.empty() && F.qry(S.back().first + S.back().second.r - S.back().second.l + 1) < x) x -= S.back().second.erase(x);
    F.erase(x);   
}
ll ans[N];
int sz[N] , wc[N] , fa[N] , top[N] , dep[N] , ld[N] , stp[N];
void dfs1(int u , int Fa) { // 重剖
    sz[u] = ld[u] = 1 , dep[u] = dep[Fa] + 1;
    for(int i = head[u] ; i ; i = edge[i].n) {
        int v = edge[i].v;
        if(v == Fa) continue;
        dfs1(v , u);
        fa[v] = u , sz[u] += sz[v];
        if(sz[wc[u]] < sz[v]) wc[u] = v;
        ld[u] = max(ld[u] , ld[v] + 1);
    }
}
void dfs2(int u , int Top) { // 单根 DP
    top[u] = Top;
    if(wc[u]) dfs2(wc[u] , Top); // 继承重儿子
    else dp[Top] = DS(0 , dep[u] - dep[Top] + 1 , ld[Top] + 1 , ld[Top] + 1); 
    stp[u] = dp[Top].sta.size(); // 记录当前状态(仅处理完重儿子)的栈大小,方便回滚
    for(int i = head[u] ; i ; i = edge[i].n) {
        int v = edge[i].v;
        if(v == fa[u] || v == wc[u]) continue;
        dfs2(v , v);
        dp[Top].merge(dp[v] , 1); // 合并轻儿子
    }
    dp[Top].insert(a[u]) , dp[Top].erase(a[u] - b[u]); // 插入当前位置并贪心移动棋子
}
void solve(int u) { // 换根 DP
    list<pair<int , node>> now = S;
    int ipos = F.sta.size();
    dp[top[u]].recall(stp[u]); // 调整重链至当前状态
    for(int i = head[u] ; i ; i = edge[i].n) {
        int v = edge[i].v;
        if(v == fa[u] || v == wc[u]) continue;
        F.merge(dp[v] , 1); // 将轻儿子合并到近棋子集合
    }
    // 将重链插入远棋子集合
    S.push_front(make_pair(0 , node(dp[top[u]].tag , dp[top[u]].l , dp[top[u]].r - 1 , &dp[top[u]].f , &dp[top[u]].g)));
    int pos = F.sta.size();
    insert(a[u]) , erase(a[u] - b[u]); // 插入当前位置并贪心移动棋子
    ans[u] = F.sum(); // 计算近棋子集合的距离和
    for (auto &v : S) ans[u] += (v.first - 1) * v.second.cnt() + v.second.sum(); // 计算远棋子集合的距离和    
    F.recall(pos) , S = now; // 恢复原来状态
    for(int i = head[u] ; i ; i = edge[i].n) { // 轻儿子换根部分
        int v = edge[i].v;
        if(v == fa[u] || v == wc[u]) continue;
        F.merge(dp[v] , -1); // 将轻儿子从近棋子集合中移除
        DS x(0 , 0 , ld[v] , ld[v]); // 拆重链
        for (int i = 0; i < ld[v] ; i ++) x.f[i] = dp[top[u]].qry(i) - dp[top[u]].qry(ld[v]);  
        F.merge(x , 1); //轻儿子范围内的插入近棋子集合,范围外的插入远棋子集合
        if(dp[top[u]].l + ld[v] <= dp[top[u]].r - 1) S.push_front(make_pair(ld[v] , 
        node(dp[top[u]].tag , dp[top[u]].l + ld[v] , dp[top[u]].r - 1 , &dp[top[u]].f , &dp[top[u]].g)));
        insert(a[u]) , erase(a[u] - b[u]);
        solve(v); 
        F.recall(pos) , S = now;
    }
    insert(a[u]) , erase(a[u] - b[u]); 
    if(wc[u]) solve(wc[u]); // 换根给重儿子
    F.recall(ipos) , S = now; 
}
signed main() {
    //cin_fast;
    in(n);
    for(int i = 1 ; i <= n ; i ++) in(a[i]);
    for(int i = 1 ; i <= n ; i ++) in(b[i]);
    for(int i = 1 ; i < n ; i ++) {
        int u , v;
        in(u) , in(v);
        eadd(u , v) , eadd(v , u);
    }
    dfs1(1 , 0) , dfs2(1 , 1) , solve(1);
    for(int i = 1 ; i <= n ; i ++) cout << ans[i] << ' ';
    return 0;
}
/*
愿你把曾经的运筹帷幄,换成人生下一程的从容落子。无论走到哪里,决胜千里的本事都跟着你。—— Kei to ARIS
this is ARIS 38
*/