题解 P6615 【草莓 Strawberry】

· · 题解

怎么是黄题啊zbl,看来我已经弱到连黄题都不会了。。

感觉是一道比较好的构造题。

我们先分析一下题意吧。有两个集合,一个存着你选过的点,另一个存着你选点过程中的两点间路径的最小值。让你最大化所有选的点权 + 最小值的乘积。

值得注意的是,选过的点不论你选多少次都算一次,不代表你不可以选多个点。

好了,看了这个题,我一点思路都没有,于是开始看subtask。

直接输出0就可以得到1分。这就是我比赛的分数。

首先看链的部分分。先说一个结论:一定选的是这个链的一段前缀和一段后缀。

我们从左到右编号,从1~5。

首先考虑路径最小值怎么算。那显然是1~4的长度和2~5的长度的最小值。

这个可以这样想:我可以按照1、5、1、4、1、5、2这样的顺序去选。这样就保证了最大化距离的最小值。

然后如果选的不是一段前缀和一段后缀,那显然不如选上。因为这样最小距离是不变的,而点权变大了。

好了,链的部分就这些。

如果是树的话,类比链的情况,可以把直径找出来。像这样的图:

找到直径之后,假设那个红色的是直径,我们可以类比链的方法,每次构造选点的顺序,都是考虑走到距离你想选的点最远的那个点,然后再选那个点,然后再跳回距离你选的点最远的点。

这样每个点对于路径最小值的贡献,就是离这个点最远的点与这个点的距离。所以我们可以按照原来的思路,把每个点和距离他距离最远的点之间的距离算出来排序,最终统计答案。

显然距离每个点最远的点是直径两个端点中的一个。

这样就可以直接统计答案了,具体看代码82~91行。

复杂度是O(n log n),瓶颈在于排序。

#include<cstdio>
#include<algorithm>
#include<cstring>
#define int long long
using std::max;
using std::min;
using std::sort;
const int maxn = 300005;
struct stu{
    int to,nxt,w;
}sexy[600005];
long long ans;
int n,hd[maxn],cnt,a[maxn];
void add(int x,int y,int z)
{
    sexy[++cnt].to = y;
    sexy[cnt].nxt = hd[x];
    hd[x] = cnt;
    sexy[cnt].w = z;
}
int dep1[maxn],dep2[maxn],rt1,rt2,dep[maxn];
void dfs1(int x,int fa){
    for(int i=hd[x];i;i=sexy[i].nxt){
        int v = sexy[i].to;
        if(v == fa) continue;
        dep[v] = dep[x] + sexy[i].w;
        dfs1(v,x);
    }
}
struct akk{
    int as,bs;
}chz[maxn];
bool xtt(akk a,akk b){
    return a.as > b.as;
}
void dfs2(int x,int fa){
    for(int i=hd[x];i;i=sexy[i].nxt){
        int v = sexy[i].to;
        if(v == fa) continue;
        dep1[v] = dep1[x] + sexy[i].w;
        dfs2(v,x);
    }
}
int zhijing; 
void dfs3(int x,int fa){
    for(int i=hd[x];i;i=sexy[i].nxt){
        int v = sexy[i].to;
        if(v == fa) continue;
        dep2[v] = dep2[x] + sexy[i].w;
        dfs3(v,x);
    }
}
long long sum1[maxn],sum2[maxn]; 
signed main()
{
    scanf("%lld",&n);
    for(int i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    for(int i=1;i<n;i++){
        int x,y,z;
        scanf("%lld%lld%lld",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    } 
    dfs1(1,0);
    int maxx = 0; 
    for(int i=1;i<=n;i++){
        if(maxx < dep[i]){
            maxx = dep[i];
            rt1 = i;
        }
    }
    dfs2(rt1,0);
    maxx = 0;
    for(int i=1;i<=n;i++){
        if(maxx < dep1[i]){
            maxx = dep1[i];
            rt2 = i;
        }
    }
    dfs3(rt2,0); 
    for(int i=1;i<=n;i++){
        chz[i].as = max(dep1[i],dep2[i]);
        chz[i].bs = a[i];
    }
    sort(chz + 1,chz + n + 1,xtt);
    maxx = 0;
    for(int i=1;i<=n;i++){
        maxx += chz[i].bs;
        ans = max(ans,maxx * chz[i].as);
    }
    printf("%lld\n",ans);
}