题解:P12481 [集训队互测 2024] 运筹帷幄/ ARIS0_0 - 38
这是一道树形 DP,但难点不在 DP,而在维护转移与换根。核心数据结构是数组和指针,维护思路相当巧妙。
1. 简单贪心
假设当前处于子树
2. 朴素 DP
我们设
易得转移方程:
合并完所有儿子的贡献后,暴力枚举到一个位置
随后对每个根都跑一遍 DP,单次 DP 的时间复杂度为
3. 初步优化
朴素 DP 很明显可以用长剖优化,这样可以将空间复杂度优化至
具体来说,对于每个长链,我们需要维护 DP 数组
对于叶子节点,我们令其所在的长链的
至于最优棋子调整,先二分找到
这样我们就把单次 DP 优化到了
4. 最终优化
长剖在换根下不好维护,于是我们把目光转向重剖,重剖显然也可以优化朴素 DP,且时间复杂度依旧是
相较于长剖,重剖具有更多优秀的性质,可以帮助我们解决这个问题。
我们需要维护三个信息,分别是重链信息、近棋子集合和远棋子集合。
当我们完成单根 DP 时,我们就求出了所有重链信息。然后考虑换根,我们先考虑怎么求答案,假设目前我们有近棋子集合、远棋子集合与重儿子的信息,这些信息恰好覆盖了整棵树,现在考虑怎么移动棋子最优。首先要优先删远棋子集合,即移动远棋子集合中的棋子。若删完远棋子集合,再考虑重儿子和近棋子集合,若未删完,则删除其一部分棋子,做一个修改即可,删除方式均为二分。但这三者的距离范围可能有交集,需要揉一块二分,实现起来还是比较复杂的。
接下来考虑怎么手搓一个数据结构维护这三个玩意,让我们来构思一下这个神秘的结构:首先,它可以动态分配内存,根据其要维护的距离范围开数组;其次,它拥有两个指针来维护有效数组范围,且可以高效访问数组内部元素;最后,它需要是一个可撤销数据结构。
于是我们自然就想到了动态数组与指针。我们开一个结构体,分别维护动态数组
- 求出其距离范围内的最小距离和
- 求出到根距离为
x 的棋子个数 - 求出有效数组长度
- 回滚至某一时刻的历史状态
- 拓展一个新节点
- 按贪心思路删除
x 个棋子 - 合并、消除相邻信息
虽然要实现的东西有点多,但是代码还是比较好写的:
:::success[核心数据结构]{open}
struct DS{
ll tag = 0 , l , r; // 删除标记与指针
vector<ll> f , g; // DP 数组,分别记录棋子数目,距离总和
vector<pair<ll* , ll>>sta; // 回滚栈
DS() {}
DS(ll tag , int l , int r , int n) {
this->tag = tag , this->l = l , this->r = r , this->f.resize(n + 5) , this->g.resize(n + 5);
}
ll sum() { // 这是移动棋子后的最小距离和,用于计算答案
return g[l + 1] - g[r] - (r - l - 1) * tag;
}
ll qry(int x) { // 返回到根距离为 x 的棋子个数
if(l + x < r) return f[l + x] - tag;
else return 0;
}
int size() { // 有效数组长度
return r - l + 1;
}
void recall(int x) { // 回滚至时刻 x 的历史状态
while(sta.size() > x) {
(*sta.back().first) = sta.back().second;
sta.pop_back();
}
}
void insert(ll x) { // 插入一个元素,并偏移指针 l
sta.push_back({&l , l --}); // 记录到回滚栈中,方便撤销操作
sta.push_back({&f[l] , f[l]});
sta.push_back({&g[l] , g[l]});
f[l] = f[l + 1] + x , g[l] = g[l + 1] + f[l]; // 转移
}
void erase(ll x) { // 删除 x 个棋子,并偏移指针 r,即移动棋子到当前根
if(size() <= 1 || x == 0) return;
sta.push_back({&tag , tag}) , tag += x; // 更新删除的棋子数目
ll it = lower_bound(f.begin() + l , f.begin() + r , tag , greater<ll>()) - f.begin(); // 二分
if (r != it) sta.push_back({&r , r}) , r = it;
}
void merge(DS o , ll t) { // 合并/消除相邻信息
if (l + o.size() > r) sta.push_back({&r , r});
for (int i = l + o.size() - 1 ; i >= l; i --) {
sta.push_back({&g[i] , g[i]});
sta.push_back({&f[i] , f[i]});
if(i >= r) f[i] = tag; // 超出部分的实际值为 0,设为 tag
f[i] += o.qry(i - l) * t;
g[i] = g[i + 1] + f[i];
}
r = max(r , l + o.size() - 1); // 更新数组长度
}
} F(0 , N , N + N , N + N) , dp[N]; // 近棋子集合、重链信息
::: 该结构可以维护近棋子集合与重链信息,但远棋子集合不需要这么复杂的结构,考虑到其设计需要完整的换根思路,我们先讲完换根部分再设计它。
所以现在我们可以思考怎样具体实现换根了!
先求以当前节点为根的答案:首先我们调整重链为仅处理完重儿子的状态,直接回滚就可以了,随后,将轻儿子合并到近棋子集合,两者距离差为
注意完成当前根的求解后要撤销棋子的移动,恢复远近棋子集合到原来的状态。
接下来考虑给轻儿子换根:首先将轻儿子从近棋子集合中移除,随后拆重链,将重儿子的信息划分成两个部分,距离在轻儿子子树高度范围内的插入近棋子集合,超出部分则插入远棋子集合,随后转移近棋子集合拓展距离为
重儿子则没有那么麻烦,拓展近棋子集合、移动棋子后直接递归就可以了,注意最后要将远近棋子集合恢复到最开始的状态。
具体思路就是这样,但是你可以发现,近棋子集合只需要维护一个结构体就可以了,但远棋子集合不一样,它里面的元素是若干个从重链上拆分出来的片段,且它们表示的最近距离到当前根的距离是按照插入时间升序排列的,因此它是一个线性数据结构,需要实现以下功能:
- 求出片段中所有棋子到其截取位置的最小距离和
- 求出这片段中的棋子个数
- 求出到截取位置距离为
x 的棋子个数 - 复制某个重链信息的一段前缀
- 按贪心思路结合近棋子集合删除若干个棋子
- 先进先出的结构,可以遍历内部元素
:::success[辅助数据结构]{open}
struct node{
ll tag , l , r;
vector<ll> *f , *g; // 指针数组,直接指向复制目标的地址,省空间
node() {}
node(ll tag , int l , int r , vector<ll> *f , vector<ll> *g) {
this->tag = tag , this->l = l , this->r = r , this->f = f , this->g = g;
}
ll sum() { // 计算这一部分的贡献
return (*g)[l] - (*g)[r + 1] - (r - l + 1) * tag;
}
ll cnt(){ // 返回这一部分的棋子个数
if(l <= r) return (*f)[l] - tag;
else return 0;
}
ll qry(int x) { // 返回到截取位置距离为 x 的棋子个数
return (*f)[l + x] - tag;
}
ll erase(ll x); // 删除不能自己独立实现
};
:::
第
list<pair<int , node>>S; // 远棋子集合
其中 pair 的第一个元素记录片段到当前根的距离,第二个元素记录片段信息。
在换根中涉及的拓展近棋子集合、移动棋子都需要远近棋子集合联合行动:
:::success[联合行动]{open}
ll node::erase(ll x) { // 删除远近棋子集合中的 x 个棋子
ll ls = 0 , rs = r - l , res = r;
while(ls <= rs) {
ll mid = (ls + rs) >> 1; // 远集合中距离为 mid 的棋子个数与近棋子集合中距离为 mid+片段距离 的棋子个数,两者可能相交
if (qry(mid) + F.qry(mid + S.back().first) >= x) res = l + mid , ls = mid + 1;
else rs = mid - 1;
}
r = res;
ll v = min((*f)[r] - tag , x - F.qry(r - l + 1 + S.back().first)); // 更新删除棋子的个数
tag += v;
return v;
}
void insert(ll x) { // 插入一个元素,并更新远棋子集合的实际距离
F.insert(x);
for (auto &v : S) v.first ++;
}
void erase(ll x) { // 删除 x 个元素,优先删远棋子集合
while (!S.empty() && F.qry(S.back().first) + S.back().second.cnt() <= x) x -= S.back().second.cnt() , S.pop_back();
if (!S.empty() && F.qry(S.back().first + S.back().second.r - S.back().second.l + 1) < x) x -= S.back().second.erase(x);
F.erase(x);
}
:::
由重链剖分的性质可以证明,远棋子集合的大小是
总时间复杂度均摊下来是
至此我们就完成这道题了!
5. 参考代码
可能讲的不是很详细,所以为代码加上了丰富的注释!
#include<bits/stdc++.h>
#define cin_fast ios::sync_with_stdio(false) , cin.tie(0) , cout.tie(0)
#define fi first
#define se second
//#define int long long
#define in(a) a = read()
#define rep(i , a , b) for(int i = a ; i <= b ; i ++)
using namespace std;
typedef long long ll;
const int N = 5e5 + 5 , mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const long long INF = 0x3f3f3f3f3f3f3f3f;
inline int read() {
int x = 0;
char ch = getchar();
bool S = 0;
while('9' < ch || ch < '0') S |= ch == '-' , ch = getchar();
while('0' <= ch && ch <= '9') x = (x << 3) + (x << 1) + ch - '0' , ch = getchar();
return S ? -x : x;
}
int n , a[N] , b[N];
struct Edge{
int n , v;
}edge[N << 1];
int head[N] , eid;
void eadd(int u , int v) {
edge[++ eid].n = head[u] , edge[head[u] = eid].v = v;
}
struct node{
ll tag , l , r;
vector<ll> *f , *g; // 指针数组,直接指向复制目标的地址,省空间
node() {}
node(ll tag , int l , int r , vector<ll> *f , vector<ll> *g) {
this->tag = tag , this->l = l , this->r = r , this->f = f , this->g = g;
}
ll sum() { // 计算这一部分的贡献
return (*g)[l] - (*g)[r + 1] - (r - l + 1) * tag;
}
ll cnt(){ // 返回这一部分的棋子个数
if(l <= r) return (*f)[l] - tag;
else return 0;
}
ll qry(int x) { // 返回到截取位置距离为 x 的棋子个数
return (*f)[l + x] - tag;
}
ll erase(ll x); // 删除操作需要和爱丽丝合作,不能自己实现
};
struct DS{
ll tag = 0 , l , r; // 删除标记与指针
vector<ll> f , g; // DP 数组,分别记录棋子数目,距离总和
vector<pair<ll* , ll>>sta; // 回滚栈
DS() {}
DS(ll tag , int l , int r , int n) {
this->tag = tag , this->l = l , this->r = r , this->f.resize(n + 5) , this->g.resize(n + 5);
}
ll sum() { // 这是移动棋子后的最小距离和,用于计算答案
return g[l + 1] - g[r] - (r - l - 1) * tag;
}
ll qry(int x) { // 返回到根距离为 x 的棋子个数
if(l + x < r) return f[l + x] - tag;
else return 0;
}
int size() { // 有效数组长度
return r - l + 1;
}
void recall(int x) { // 回滚至时刻 x 的历史状态
while(sta.size() > x) {
(*sta.back().first) = sta.back().second;
sta.pop_back();
}
}
void insert(ll x) { // 插入一个元素,并偏移指针 l
sta.push_back({&l , l --}); // 记录到回滚栈中,方便撤销操作
sta.push_back({&f[l] , f[l]});
sta.push_back({&g[l] , g[l]});
f[l] = f[l + 1] + x , g[l] = g[l + 1] + f[l]; // 转移
}
void erase(ll x) { // 删除 x 个棋子,并偏移指针 r,即移动棋子到当前根
if(size() <= 1 || x == 0) return;
sta.push_back({&tag , tag}) , tag += x; // 更新删除的棋子数目
ll it = lower_bound(f.begin() + l , f.begin() + r , tag , greater<ll>()) - f.begin(); // 二分
if (r != it) sta.push_back({&r , r}) , r = it;
}
void merge(DS o , ll t) { // 合并/消除相邻信息
if (l + o.size() > r) sta.push_back({&r , r});
for (int i = l + o.size() - 1 ; i >= l; i --) {
sta.push_back({&g[i] , g[i]});
sta.push_back({&f[i] , f[i]});
if(i >= r) f[i] = tag; // 超出部分的实际值为 0,设为 tag
f[i] += o.qry(i - l) * t;
g[i] = g[i + 1] + f[i];
}
r = max(r , l + o.size() - 1); // 更新数组长度
}
} F(0 , N , N + N , N + N) , dp[N]; // 近棋子集合、重链信息
list<pair<int , node>>S; // 远棋子集合
ll node::erase(ll x) { // 删除远近棋子集合中的 x 个棋子
ll ls = 0 , rs = r - l , res = r;
while(ls <= rs) {
ll mid = (ls + rs) >> 1; // 远集合中距离为 mid 的棋子个数与近棋子集合中距离为 mid+片段距离 的棋子个数,两者可能相交
if (qry(mid) + F.qry(mid + S.back().first) >= x) res = l + mid , ls = mid + 1;
else rs = mid - 1;
}
r = res;
ll v = min((*f)[r] - tag , x - F.qry(r - l + 1 + S.back().first)); // 更新删除棋子的个数
tag += v;
return v;
}
void insert(ll x) { // 插入一个元素,并更新远棋子集合的实际距离
F.insert(x);
for (auto &v : S) v.first ++;
}
void erase(ll x) { // 删除 x 个元素,优先删远棋子集合
while (!S.empty() && F.qry(S.back().first) + S.back().second.cnt() <= x) x -= S.back().second.cnt() , S.pop_back();
if (!S.empty() && F.qry(S.back().first + S.back().second.r - S.back().second.l + 1) < x) x -= S.back().second.erase(x);
F.erase(x);
}
ll ans[N];
int sz[N] , wc[N] , fa[N] , top[N] , dep[N] , ld[N] , stp[N];
void dfs1(int u , int Fa) { // 重剖
sz[u] = ld[u] = 1 , dep[u] = dep[Fa] + 1;
for(int i = head[u] ; i ; i = edge[i].n) {
int v = edge[i].v;
if(v == Fa) continue;
dfs1(v , u);
fa[v] = u , sz[u] += sz[v];
if(sz[wc[u]] < sz[v]) wc[u] = v;
ld[u] = max(ld[u] , ld[v] + 1);
}
}
void dfs2(int u , int Top) { // 单根 DP
top[u] = Top;
if(wc[u]) dfs2(wc[u] , Top); // 继承重儿子
else dp[Top] = DS(0 , dep[u] - dep[Top] + 1 , ld[Top] + 1 , ld[Top] + 1);
stp[u] = dp[Top].sta.size(); // 记录当前状态(仅处理完重儿子)的栈大小,方便回滚
for(int i = head[u] ; i ; i = edge[i].n) {
int v = edge[i].v;
if(v == fa[u] || v == wc[u]) continue;
dfs2(v , v);
dp[Top].merge(dp[v] , 1); // 合并轻儿子
}
dp[Top].insert(a[u]) , dp[Top].erase(a[u] - b[u]); // 插入当前位置并贪心移动棋子
}
void solve(int u) { // 换根 DP
list<pair<int , node>> now = S;
int ipos = F.sta.size();
dp[top[u]].recall(stp[u]); // 调整重链至当前状态
for(int i = head[u] ; i ; i = edge[i].n) {
int v = edge[i].v;
if(v == fa[u] || v == wc[u]) continue;
F.merge(dp[v] , 1); // 将轻儿子合并到近棋子集合
}
// 将重链插入远棋子集合
S.push_front(make_pair(0 , node(dp[top[u]].tag , dp[top[u]].l , dp[top[u]].r - 1 , &dp[top[u]].f , &dp[top[u]].g)));
int pos = F.sta.size();
insert(a[u]) , erase(a[u] - b[u]); // 插入当前位置并贪心移动棋子
ans[u] = F.sum(); // 计算近棋子集合的距离和
for (auto &v : S) ans[u] += (v.first - 1) * v.second.cnt() + v.second.sum(); // 计算远棋子集合的距离和
F.recall(pos) , S = now; // 恢复原来状态
for(int i = head[u] ; i ; i = edge[i].n) { // 轻儿子换根部分
int v = edge[i].v;
if(v == fa[u] || v == wc[u]) continue;
F.merge(dp[v] , -1); // 将轻儿子从近棋子集合中移除
DS x(0 , 0 , ld[v] , ld[v]); // 拆重链
for (int i = 0; i < ld[v] ; i ++) x.f[i] = dp[top[u]].qry(i) - dp[top[u]].qry(ld[v]);
F.merge(x , 1); //轻儿子范围内的插入近棋子集合,范围外的插入远棋子集合
if(dp[top[u]].l + ld[v] <= dp[top[u]].r - 1) S.push_front(make_pair(ld[v] ,
node(dp[top[u]].tag , dp[top[u]].l + ld[v] , dp[top[u]].r - 1 , &dp[top[u]].f , &dp[top[u]].g)));
insert(a[u]) , erase(a[u] - b[u]);
solve(v);
F.recall(pos) , S = now;
}
insert(a[u]) , erase(a[u] - b[u]);
if(wc[u]) solve(wc[u]); // 换根给重儿子
F.recall(ipos) , S = now;
}
signed main() {
//cin_fast;
in(n);
for(int i = 1 ; i <= n ; i ++) in(a[i]);
for(int i = 1 ; i <= n ; i ++) in(b[i]);
for(int i = 1 ; i < n ; i ++) {
int u , v;
in(u) , in(v);
eadd(u , v) , eadd(v , u);
}
dfs1(1 , 0) , dfs2(1 , 1) , solve(1);
for(int i = 1 ; i <= n ; i ++) cout << ans[i] << ' ';
return 0;
}
/*
愿你把曾经的运筹帷幄,换成人生下一程的从容落子。无论走到哪里,决胜千里的本事都跟着你。—— Kei to ARIS
this is ARIS 38
*/