题解 P6748 【Fallen Lord】
囧仙
2020-08-08 20:59:17
一道感觉有一定思维难度的树形 dp 题
- **中位数**
做这道题,首先要明白这个中位数 $\le a_i$ 的本质
根据题目给的定义,中位数就是第 $\left\lfloor\dfrac{k}{2}\right\rfloor+1$ 小的数
由此可以推得,前 $\left\lfloor\dfrac{k}{2}\right\rfloor+1$ 小的数都必须 $\le a_i$
也就是说,至多只能有 $k - \left\lfloor\dfrac{k}{2}\right\rfloor - 1$ 个数 $>a_i$
- **设计 dp**
现在题意已经被我们转化为:
给你一颗树,要给每条边标上一个 $[1,m]$ 中的边权,满足每个点直接相连的边中,至多有 $b_i$ 条边的权 $>a_i$,且边权和最大。
那么这个东西一看就很 dp,因为不与 $i$ 直接相连的边与 $i$ 点都是无关的,就没有后效性
但是问题就来了,一般在树上 dp 我们都是设 $dp_i$ 为以 $i$ 为根的子树的某些信息,然后用一遍 dfs 来进行转移,但是当我们给以 $i$ 为根的子树中的所有边都标上边权时,我们却没法确定 $i$ 点加上连接 $i$ 与 $i$ 父亲的边是否依然合法
所以我们改一下状态,改为 $dp_i$ 表示给以 $i$ 为根的子树内的所有边都标上了边权,且给连接 $i$ 与 $i$ 的父亲的边也标上了边权,以 $i$ 为根的子树内的所有点都合法的最大边权和
但是这样就还有一个问题:连接 $i$ 与 $i$ 父亲的边可能大于 $a_{fa}$,也可能不大于 $a_{fa}$
所以还得分类,改为 $dp_{i,0}$ 表示给以 $i$ 为根的子树内的所有边都标上了边权,且给连接 $i$ 与 $i$ 父亲的边也标上了边权,且这条边的权 $\le a_{fa}$ 的最大边权和
$dp_{i,1}$ 同理,只不过这条边的边权 $> a_{fa}$
- **dp 转移**
dp 只有一个没有后效性的状态设计可不够,还要有转移
那么在从 $now$ 的儿子向 $now$ 转移时,实际上每个 $dp_{son,1}$ 就使得 $now$ 点相连的边多了一条 $>a_{now}$ 的
所以我们要在所有 $now$ 的儿子中,选择一些 $dp_{son,0}$,选择一些 $dp_{son,1}$,最后确定连接 $now$ 与 $fa$ 的边的边权
实际上,确定连接 $now$ 与 $fa$ 的边的边权反而是最重要的,要最先做,因为只有确定了这个,你才能知道可以最多选多少个 $dp_{son,1}$
1. 连接 $now$ 与 $fa$ 的边小于等于 $now$ 与 $fa$ 的权,则我们要选至多 $b_{now}$ 个 $dp_{son,1}$,其余选 $dp_{son,0}$,边权当然是 $\min(a_{now},a_{fa})$ 最优,状态转移到 $dp_{now,0}$
2. 连接 $now$ 与 $fa$ 的边小于 $now$ 的权,大于 $fa$ 的权,实际上选择的内容和 1 是一样的,只不过边权改为 $a_{now}$,转移到 $dp_{now,1}$
3. 连接 $now$ 与 $fa$ 的边大于 $now$ 的权,小于 $fa$ 的权,则我们要选至多 $b_{now} - 1$ 个 $dp_{son,1}$,其余选 $dp_{son,0}$,边权选 $a_{fa}$,转移到 $dp_{now,0}$
4. 连接 $now$ 与 $fa$ 的边大于 $now,fa$ 的权,选择的内容与 3 一样,边权 $m$ 最优,转移到 $dp_{now,1}$
- **解决选择的问题**
现在问题被我们转化为这样:
给你两个长度为 $n$ 的序列 $a,b$,你可以使最多 $k$ 个 $i$ 的贡献为 $b_i$,其余贡献为 $a_i$,要求贡献最大
那么这个感觉就很可以用贪心来做,设 $c_i = b_i - a_i$,显然要选 $c_i$ 最大的 $k$ 个(如果 $c_i < 0$ 就不选了)
- **总结**
然后这道题就做完了,感觉还是挺复杂的,不知道是不是我把问题弄复杂了
还有一些小细节,比如转移根节点的时候不能标根节点与根节点父亲的权
最后放一下代码:
```cpp
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
int n,m;
int a[500005];
int cnt;
int head[500005];
struct eg{
int to,nxt;
}edge[1000005];
void make(int u,int v){
edge[++cnt].to = v;
edge[cnt].nxt = head[u];
head[u] = cnt;
}
int deg[500005];
//给每个边标一个边权,每个点相连的边中,至多 k 条边的权 > a[i]
ll dp[500005][2];//dp[i][0] 表示以 i 为根的子树内的所有边,以及连接 i 与 i 父亲的边都标上了权且合法,连接 i 与 i 父亲的边 < a[fa[i]]
//dp[i][1] 表示大于
//前 [k / 2] + 1 个 <=
//k - [k / 2] - 1 个大于
//1. 连接 i 与 i 父亲的边小于 i 与 i 父亲的权,则 dp[i][0] = 前 k 大的 dp[son][1] + 其他的 dp[son][0] + min(a[i],a[fa[i]])
//2. 连接 i 与 i 父亲的边小于 i 的权,大于 i 父亲的权 dp[i][1] = 前 k 大的 dp[son][1] + 其他的 dp[son][0] + a[i]
//3. 连接 i 与 i 父亲的边大于 i 的权,小于 i 父亲的权 dp[i][0] = 前 k - 1 大的 dp[son][1] + 其他的 dp[son][0] + a[fa[i]]
//4. 连接 i 与 i 父亲的边大于 i 的权,大于 i 父亲的权 dp[i][1] = 前 k - 1 大的 dp[son][1] + 其他的 dp[son][0] + m
int len;
int Q[500005];
bool cmp(int x,int y){
return (dp[x][1] - dp[x][0]) > (dp[y][1] - dp[y][0]);
}
ll max(ll a,ll b){
if(a < b) return b;
return a;
}
ll min(ll a,ll b){
if(a < b) return a;
return b;
}
void dfs(int now,int fa){
for(int i = head[now];i;i = edge[i].nxt){
if(edge[i].to == fa) continue;
dfs(edge[i].to,now);
}
len = 0;
for(int i = head[now];i;i = edge[i].nxt){
if(edge[i].to == fa) continue;
Q[++len] = edge[i].to;
}
/*printf("node %d: fa %d sons : \n",now,fa);
for(int i = 1;i <= len;i++) printf("%d ",Q[i]);
printf("\n"); */
std::sort(Q + 1,Q + len + 1,cmp);
ll tmp = 0;
for(int i = 1;i <= len;i++){
if(i <= deg[now]) tmp += dp[Q[i]][1];
else tmp += dp[Q[i]][0];
}
if(now != 1) tmp += min(a[now],a[fa]);
dp[now][0] = max(dp[now][0],tmp);
if(a[now] > a[fa]){
tmp = 0;
for(int i = 1;i <= len;i++){
if(i <= deg[now]) tmp += dp[Q[i]][1];
else tmp += dp[Q[i]][0];
}
if(now != 1) tmp += a[now];
dp[now][1] = max(dp[now][1],tmp);
}
dp[now][1] = max(dp[now][0],dp[now][1]);
if(!deg[now]) return;
tmp = 0;
if(a[now] < a[fa]){
for(int i = 1;i <= len;i++){
if(i < deg[now]) tmp += dp[Q[i]][1];
else tmp += dp[Q[i]][0];
}
if(now != 1) tmp += a[fa];
dp[now][0] = max(dp[now][0],tmp);
}
tmp = 0;
for(int i = 1;i <= len;i++){
if(i < deg[now]) tmp += dp[Q[i]][1];
else tmp += dp[Q[i]][0];
}
if(now != 1) tmp += m;
dp[now][1] = max(dp[now][1],tmp);
dp[now][1] = max(dp[now][0],dp[now][1]);
}
int main(){
memset(dp,0xcf,sizeof(dp));
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;i++){
scanf("%d",&a[i]);
}
for(int i = 1;i < n;i++){
int u,v;
scanf("%d%d",&u,&v);
make(u,v);
make(v,u);
deg[u]++;
deg[v]++;
}
for(int i = 1;i <= n;i++){
deg[i] = deg[i] - deg[i] / 2 - 1;
}
a[0] = m;
dfs(1,0);
printf("%lld\n",max(dp[1][0],max(dp[1][1],-1)));
return 0;
}
```