题解 P4383 【[八省联考2018]林克卡特树lct】
Marser
2019-09-08 22:49:17
## 题意
给定一棵 $n$ 个点的树,边权有正有负,要求在树上选出 $k+1$ 条链,使得其权值之和最大。
看起来很不好做,那我们先考虑一下,对于60分的部分分如何处理。
我们可以这样设计状态:
令 $f[i][j][0/1/2]$ 表示在 $i$ 节点的子树内,已经有 $j$ 条**完整的链**,当前 $i$ 节点的度数为 $0/1/2$ 的最大价值。度数为 $0$ 时,这个点没有连边。度数为 $1$ 时,这个点拖着一条未完成的链,而这条链不计入 $j$ 。度数为 $2$ 时,这个点被一条连接两个不同子树的链穿过。
可以分类讨论出如下转移:
首先,我们约定在每个节点的全部转移结束时,进行一次更新:
$f[i][j][0]=\max\{f[i][j][0],f[i][j-1][1],f[i][j][2]\}$
这样,我们就把 $i$ 节点的全部最优解统计了出来。对于度数为 $0/2$ 的情况可以直接合并,而对于度数为 $1$ 的情况,要先在 $i$ 节点处把当前 $i$ 节点拖着的一条链断掉,然后将这条链计入总数,合并。
$f[i][j][0]=\max_{l=1}^{j-1} \{ f[i][j][0],f[i][l][0]+f[x][j-l][0]\} $
显然,如果当前节点的度数为 $0$ ,肯定不能连边,只能取子节点的最优解。
$f[i][j][1]=\max_{l=1}^{j-1} \{ f[i][j][1],f[i][l][1]+f[x][j-l][0],f[i][l][0]+f[x][j-l][1]+w(i,x)\} $
如果要求度数为 $1$ ,可以继承之前的结果,不连边。同样,可以连上这条边,继承子节点拖着的一条未完成的链,同时取这条边的权值。
$f[i][j][2]=\max_{l=1}^{j-1} \{ f[i][j][2],f[i][l][2]+f[x][j-l][0],f[i][l][1]+f[x][j-l-1][1]+w(i,x)\} $
要求度数为 $2$ ,同样可以继承之前的结果,取子节点的最优解。也可以将之前 $i$ 节点挂着的链与子节点挂着的链连接起来,就新完成了一条链。
$f[i][j][1]=\max\{f[i][j][1],f[i][j][0]+f[x][0][1]+w(i,x)\}$
$f[i][0][1]=\max\{f[i][0][1],f[i][0][0]+f[x][0][1]+w(i,x)\}$
处理从子节点出发的链。
这样我们可以得到一个 $O(nk^2)$ 的dp做法,可以拿到45(35)分。
考虑如何优化。我们注意到,每次增加一条链,得到的收益是单调不增的。
每次增加一条链有两种做法,一种是新选一条链,一种是将一条链分成两条。每种操作的收益是一定的,因此每次都会选择最大收益的操作。而由于边权不会改变,后续不可能有操作会有更大的收益。因此,我们可以口胡出这是一个上凸的函数。
记 $f[1][x][0]$ 为 $F(x)$ ,我们知道 $F(x)$ 是一个上凸函数。并且,恒有 $F(0)=F(n)=0$ 。因此,我们可以考虑用一条斜率一定的直线去切这个凸壳。设这条直线为 $l:y=ax+b$ ,根据上凸函数的特性,斜率一定时,切线的截距一定大于割线。
我们令切点为 $(t,F(t))$ ,则切线满足 $a*t+b=F(t)$ ,即 $b=F(t) - a*t$ 。可以发现, $b$ 的表达式就相当于给每个物品赋上额外的权值 $-t$ 时,dp所得的最优解。这样,我们只需要一次朴素的dp就可以求出截距。
而如何获取斜率呢?再一次,由于函数上凸,我们可以发现当斜率不断增大时,对应的切点横坐标也在不断左移。这样,我们就可以二分求相应的斜率了。每次check的时候顺便记录一下最优解选取了多少个,从而判断应该如何调整二分区间。
还有一个问题,有可能出现凸壳上多个点共线的情况。为了处理这种情况,我们可以限定在dp过程中,权值相同时优先取选择次数更小的转移。这样,我们求出来的次数就是当前切点的左边界,二分时判断一下就可以了。
回到这题,套用60(45)分做法的dp方式,去掉有关次数的限制,每次dp的复杂度就变成 $O(n)$ 。总复杂度为 $O(n \log k)$
45分代码
```cpp
#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,K;
int f[MN][105][3];
void dfs(int st,int fa=0){
f[st][0][0]=f[st][0][1]=f[st][1][2]=0;
for(reg int i=h[st];i;i=nxt[i]){
if(to[i]==fa)continue;
dfs(to[i],st);
for(reg int j=K;j;j--){
chkmax(f[st][j][1],f[st][j][0]+f[to[i]][0][1]+c[i]);
for(reg int k=j-1;~k;k--){
chkmax(f[st][j][0],f[st][k][0]+f[to[i]][j-k][0]);
chkmax(f[st][j][1],max(f[st][k][1]+f[to[i]][j-k][0],f[st][k][0]+f[to[i]][j-k][1]+c[i]));
chkmax(f[st][j][2],max(f[st][k][2]+f[to[i]][j-k][0],f[st][k][1]+f[to[i]][j-k-1][1]+c[i]));
}
}
chkmax(f[st][0][1],f[to[i]][0][1]+c[i]);
}
for(reg int i=1;i<=K;i++)
chkmax(f[st][i][0],max(f[st][i-1][1],f[st][i][2]));
}
int main(){
scanf("%d%d",&n,&K);K++;
for(reg int i=1,s,t,v;i<n;i++)
scanf("%d%d%d",&s,&t,&v),ins(s,t,v);
memset(f,~0x3f,sizeof(f));dfs(1);
printf("%d\n",f[1][K][0]);
return 0;
}
```
100分代码
```cpp
#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,k;
ll l,r,mid;
struct data{
ll val;int pos;
data(ll x=0,int y=0):val(x),pos(y){}
friend bool operator<(data a,data b){
return a.val==b.val?a.pos>b.pos:a.val<b.val;
}
friend data operator+(data a,data b){
return data(a.val+b.val,a.pos+b.pos);
}
friend data operator+(data a,ll b){
return data(a.val+b,a.pos);
}
}f[MN][3],tmp;
int fa[MN];
void getf(int st){
for(reg int i=h[st];i;i=nxt[i])
if(to[i]!=fa[st])fa[to[i]]=st,getf(to[i]);
}
void dfs(int st){
f[st][0]=f[st][1]=f[st][2]=data();
chkmax(f[st][2],tmp);
for(reg int i=h[st];i;i=nxt[i]){
if(to[i]==fa[st])continue;dfs(to[i]);
chkmax(f[st][2],max(f[st][2]+f[to[i]][0],f[st][1]+f[to[i]][1]+c[i]+tmp));
chkmax(f[st][1],max(f[st][1]+f[to[i]][0],f[st][0]+f[to[i]][1]+c[i]));
chkmax(f[st][0],f[st][0]+f[to[i]][0]);
}
chkmax(f[st][0],max(f[st][1]+tmp,f[st][2]));
}
int main(){
scanf("%d%d",&n,&k);k++;
for(reg int i=1,s,t,w;i<n;i++)
scanf("%d%d%d",&s,&t,&w),ins(s,t,w);
l=-1e12;r=1e12;getf(1);
while(l<r){
mid=(double)(l+r)/2-0.5;
tmp=data(-mid,1);dfs(1);
if(f[1][0].pos==k){
printf("%lld\n",f[1][0].val+mid*k);
return 0;
}
if(f[1][0].pos>k)l=mid+1;
else r=mid;
}
mid=l;tmp=data(-mid,1);dfs(1);
printf("%lld\n",f[1][0].val+mid*k);
return 0;
}
```