题解 P4383 【[八省联考2018]林克卡特树lct】

Marser

2019-09-08 22:49:17

Solution

## 题意 给定一棵 $n$ 个点的树,边权有正有负,要求在树上选出 $k+1$ 条链,使得其权值之和最大。 看起来很不好做,那我们先考虑一下,对于60分的部分分如何处理。 我们可以这样设计状态: 令 $f[i][j][0/1/2]$ 表示在 $i$ 节点的子树内,已经有 $j$ 条**完整的链**,当前 $i$ 节点的度数为 $0/1/2$ 的最大价值。度数为 $0$ 时,这个点没有连边。度数为 $1$ 时,这个点拖着一条未完成的链,而这条链不计入 $j$ 。度数为 $2$ 时,这个点被一条连接两个不同子树的链穿过。 可以分类讨论出如下转移: 首先,我们约定在每个节点的全部转移结束时,进行一次更新: $f[i][j][0]=\max\{f[i][j][0],f[i][j-1][1],f[i][j][2]\}$ 这样,我们就把 $i$ 节点的全部最优解统计了出来。对于度数为 $0/2$ 的情况可以直接合并,而对于度数为 $1$ 的情况,要先在 $i$ 节点处把当前 $i$ 节点拖着的一条链断掉,然后将这条链计入总数,合并。 $f[i][j][0]=\max_{l=1}^{j-1} \{ f[i][j][0],f[i][l][0]+f[x][j-l][0]\} $ 显然,如果当前节点的度数为 $0$ ,肯定不能连边,只能取子节点的最优解。 $f[i][j][1]=\max_{l=1}^{j-1} \{ f[i][j][1],f[i][l][1]+f[x][j-l][0],f[i][l][0]+f[x][j-l][1]+w(i,x)\} $ 如果要求度数为 $1$ ,可以继承之前的结果,不连边。同样,可以连上这条边,继承子节点拖着的一条未完成的链,同时取这条边的权值。 $f[i][j][2]=\max_{l=1}^{j-1} \{ f[i][j][2],f[i][l][2]+f[x][j-l][0],f[i][l][1]+f[x][j-l-1][1]+w(i,x)\} $ 要求度数为 $2$ ,同样可以继承之前的结果,取子节点的最优解。也可以将之前 $i$ 节点挂着的链与子节点挂着的链连接起来,就新完成了一条链。 $f[i][j][1]=\max\{f[i][j][1],f[i][j][0]+f[x][0][1]+w(i,x)\}$ $f[i][0][1]=\max\{f[i][0][1],f[i][0][0]+f[x][0][1]+w(i,x)\}$ 处理从子节点出发的链。 这样我们可以得到一个 $O(nk^2)$ 的dp做法,可以拿到45(35)分。 考虑如何优化。我们注意到,每次增加一条链,得到的收益是单调不增的。 每次增加一条链有两种做法,一种是新选一条链,一种是将一条链分成两条。每种操作的收益是一定的,因此每次都会选择最大收益的操作。而由于边权不会改变,后续不可能有操作会有更大的收益。因此,我们可以口胡出这是一个上凸的函数。 记 $f[1][x][0]$ 为 $F(x)$ ,我们知道 $F(x)$ 是一个上凸函数。并且,恒有 $F(0)=F(n)=0$ 。因此,我们可以考虑用一条斜率一定的直线去切这个凸壳。设这条直线为 $l:y=ax+b$ ,根据上凸函数的特性,斜率一定时,切线的截距一定大于割线。 我们令切点为 $(t,F(t))$ ,则切线满足 $a*t+b=F(t)$ ,即 $b=F(t) - a*t$ 。可以发现, $b$ 的表达式就相当于给每个物品赋上额外的权值 $-t$ 时,dp所得的最优解。这样,我们只需要一次朴素的dp就可以求出截距。 而如何获取斜率呢?再一次,由于函数上凸,我们可以发现当斜率不断增大时,对应的切点横坐标也在不断左移。这样,我们就可以二分求相应的斜率了。每次check的时候顺便记录一下最优解选取了多少个,从而判断应该如何调整二分区间。 还有一个问题,有可能出现凸壳上多个点共线的情况。为了处理这种情况,我们可以限定在dp过程中,权值相同时优先取选择次数更小的转移。这样,我们求出来的次数就是当前切点的左边界,二分时判断一下就可以了。 回到这题,套用60(45)分做法的dp方式,去掉有关次数的限制,每次dp的复杂度就变成 $O(n)$ 。总复杂度为 $O(n \log k)$ 45分代码 ```cpp #include<bits/stdc++.h> #define reg register typedef long long ll; using namespace std; const int MN=3e5+5; int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt; inline void ins(int s,int t,int w){ to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt; to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt; } #define chkmax(a,b) ((a)<(b)?(a)=(b):0) int n,K; int f[MN][105][3]; void dfs(int st,int fa=0){ f[st][0][0]=f[st][0][1]=f[st][1][2]=0; for(reg int i=h[st];i;i=nxt[i]){ if(to[i]==fa)continue; dfs(to[i],st); for(reg int j=K;j;j--){ chkmax(f[st][j][1],f[st][j][0]+f[to[i]][0][1]+c[i]); for(reg int k=j-1;~k;k--){ chkmax(f[st][j][0],f[st][k][0]+f[to[i]][j-k][0]); chkmax(f[st][j][1],max(f[st][k][1]+f[to[i]][j-k][0],f[st][k][0]+f[to[i]][j-k][1]+c[i])); chkmax(f[st][j][2],max(f[st][k][2]+f[to[i]][j-k][0],f[st][k][1]+f[to[i]][j-k-1][1]+c[i])); } } chkmax(f[st][0][1],f[to[i]][0][1]+c[i]); } for(reg int i=1;i<=K;i++) chkmax(f[st][i][0],max(f[st][i-1][1],f[st][i][2])); } int main(){ scanf("%d%d",&n,&K);K++; for(reg int i=1,s,t,v;i<n;i++) scanf("%d%d%d",&s,&t,&v),ins(s,t,v); memset(f,~0x3f,sizeof(f));dfs(1); printf("%d\n",f[1][K][0]); return 0; } ``` 100分代码 ```cpp #include<bits/stdc++.h> #define reg register typedef long long ll; using namespace std; const int MN=3e5+5; int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt; inline void ins(int s,int t,int w){ to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt; to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt; } #define chkmax(a,b) ((a)<(b)?(a)=(b):0) int n,k; ll l,r,mid; struct data{ ll val;int pos; data(ll x=0,int y=0):val(x),pos(y){} friend bool operator<(data a,data b){ return a.val==b.val?a.pos>b.pos:a.val<b.val; } friend data operator+(data a,data b){ return data(a.val+b.val,a.pos+b.pos); } friend data operator+(data a,ll b){ return data(a.val+b,a.pos); } }f[MN][3],tmp; int fa[MN]; void getf(int st){ for(reg int i=h[st];i;i=nxt[i]) if(to[i]!=fa[st])fa[to[i]]=st,getf(to[i]); } void dfs(int st){ f[st][0]=f[st][1]=f[st][2]=data(); chkmax(f[st][2],tmp); for(reg int i=h[st];i;i=nxt[i]){ if(to[i]==fa[st])continue;dfs(to[i]); chkmax(f[st][2],max(f[st][2]+f[to[i]][0],f[st][1]+f[to[i]][1]+c[i]+tmp)); chkmax(f[st][1],max(f[st][1]+f[to[i]][0],f[st][0]+f[to[i]][1]+c[i])); chkmax(f[st][0],f[st][0]+f[to[i]][0]); } chkmax(f[st][0],max(f[st][1]+tmp,f[st][2])); } int main(){ scanf("%d%d",&n,&k);k++; for(reg int i=1,s,t,w;i<n;i++) scanf("%d%d%d",&s,&t,&w),ins(s,t,w); l=-1e12;r=1e12;getf(1); while(l<r){ mid=(double)(l+r)/2-0.5; tmp=data(-mid,1);dfs(1); if(f[1][0].pos==k){ printf("%lld\n",f[1][0].val+mid*k); return 0; } if(f[1][0].pos>k)l=mid+1; else r=mid; } mid=l;tmp=data(-mid,1);dfs(1); printf("%lld\n",f[1][0].val+mid*k); return 0; } ```