P9755 [CSP-S 2023] 种树 题解

· · 题解

本题解将通过链和特殊性质 A 两档部分分引导读者走向正解。请各位不要跳读。

我们首先来考虑链的部分分。

不难发现,这部分的关键是实现一个形如 check(x,l,r) 的函数,表示 x 结点在 [l,r] 这段时间的生长高度能否达成目标。同时,这也是正解的一个关键函数。

不妨列出 x 结点高度的表达式:

\sum_{i=l}^{r}\max(1,b_x + ic_x)

先考虑 c_x \geq 0 的情况,那么上式可按如下化简:

\sum_{i=l}^{r}b_x + ic_x =\sum_{i=l}^{r}b_x+\sum_{i=l}^{r}ic_x =b_x(r-l+1)+c_x\frac{(l+r)(r-l+1)}{2}

接下来考虑 c_x<0,此时我们不妨找到使得 b_x + ic_x \geq 1 的最大的 i,不难得出:i_{max} = \lfloor \frac{1-b_x}{c_x} \rfloor 。接下来分情况讨论,可得:

& i_{max} < l \\ b_x(r-l+1)+c_x\frac{(l+r)(r-l+1)}{2} & i_{max}> r \\b_x(i_{max}-l+1)+c_x\frac{(l+i_{max})(i_{max}-l+1)}{2}+r-i_{max} & i_{max} \in [l,r] \end{cases}

至此这个函数实现完毕。而统计答案,你就对链上的第 i 个点二分出最小的 r 满足 check(i,i,r) 为真的 r,所有 r 取个 \max 即可。

特殊性质 A

接着我们思考特殊性质 A。

不难发现此时每个结点生长到目标高度所需时间与种下的时间无关,即有 t_x = \lceil \frac{a_x}{b_x} \rceil

将结点按 t_x 从大到小排序。考虑这样一个贪心,我们顺次考虑所有结点,若该结点已被种树就跳过,否则将根到该结点这条链按顺序把未种树的结点种树。显然这样做会得到一个种树顺序的序列,显然由于 t_x 更小的结点不是时间的瓶颈,我们这样做是最优的。

实际实现过程中,我们标记一下每个结点是否种过树。当考虑当前结点 x 时,暴力跳到最后一个未被标记的祖先结点,然后倒序再给每个结点赋上开始种树的时间 s_x。取 s_x+t_x-1 的最大值即可。

正解

想明白了前两个部分,正解就是容易的。

考虑一般情况与特殊性质 A 的区别,不难发现我们无法得到上述的 t_x 了,究其根本是因为树种下的时间会影响种树需要的时间。

我们考虑二分答案,并修改 t_x 的定义为要使该结点合法的最晚种树时间。显然在该前提下,t_x 可以用二分答案加上链部分所实现的函数求出。而这时我们按照 t_x 从小到大考虑,不难发现问题就转化为了特殊性质 A 时的问题,套用以上做法即可解决。

时间复杂度 O(n \log n \log v)。用桶排并上二次函数相关知识应该能将 \log n 去掉,不过没必要。

代码:

#include<bits/stdc++.h>
using namespace std;
#define ___ __int128
const int N=1e5+10;
int n,b[N],c[N],p[N],t[N],fa[N],stk[N];
int h[N],e[N<<1],ne[N<<1],idx;
bool vis[N];
long long a[N];
inline void add(int a,int b)
{
    e[idx]=b;ne[idx]=h[a];h[a]=idx++;
}
inline void dfs(int u,int p){fa[u]=p;for(int i=h[u];~i;i=ne[i]) if(e[i]!=p) dfs(e[i],u);}
inline ___ calc(int x,___ l,___ r)
{
    if(c[x]>=0) return (r-l+1)*b[x]+(r-l+1)*(l+r)/2*c[x];
    ___ T=(1-b[x])/c[x];
    if(T<l) return r-l+1;
    if(T>r) return (r-l+1)*b[x]+(r-l+1)*(l+r)/2*c[x];
    return (T-l+1)*b[x]+(T-l+1)*(l+T)/2*c[x]+r-T;
}
inline bool check(int r)
{
    for(int i=1;i<=n;i++) 
    {
        if(calc(i,1,r)<a[i]) return false;
        int dl=1,dr=n;
        while(dl<dr)
        {
            int mid=(dl+dr+1)>>1;
            if(calc(i,mid,r)>=a[i]) dl=mid;
            else dr=mid-1;
        }
        p[i]=i;t[i]=dl;vis[i]=false;
    }
    sort(p+1,p+n+1,[](int A,int B){return t[A]<t[B];});
    for(int i=1,x=0;i<=n;i++)
    {
        int now=p[i],top=0;
        while(!vis[now]) vis[stk[++top]=now]=true,now=fa[now];
        while(top) if(t[stk[top--]]<++x) return false;
    }
    return true;
}
int main()
{
    memset(h,-1,sizeof h);
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%lld%d%d",&a[i],&b[i],&c[i]);
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs(1,0);vis[0]=true;
    int l=n,r=1e9;
    while(l<r)
    {
        int mid=(l+r)>>1;
        if(check(mid)) r=mid;
        else l=mid+1;
    } 
    printf("%d",l);
    return 0;
}