题解:P16435 [APIO 2026 中国赛区] 集宝

· · 题解

这个【数据删除】在考场上死磕一正场T3,最终拼尽全力写出 \log^2 然后死于大常数,荣获暴力分,喜提铁牌。

题目分析

题目给出的相当于是若干个树上的圆,我们要从一个点开始依次走到区间内每一个圆至少一次。

我们首先有一个贪心:我们如果走到一个圆的边界就能停的话我们一定会停。因为此时停下来一定比走到圆内其它点不劣。

我们考虑相邻两个圆之间的关系:

对于后者,我们有(广为人知的)树上圆交的结论:树上两圆相交,交集是一个圆,并且圆心只会在树上的点或者边的中点处取到。手玩一下可以理解。

为了方便维护我们可以在每一条边上插入一个结点,这样我们就只用考虑圆心在点上了。

于是我们发现,相交两圆可以使用一个等价的圆来代替。而且我们能够计算出这个圆。

我们考虑如何处理区间询问。考虑使用线段树, 我们发现有两类数据:

  1. 前文中相离的圆,会坍缩成一个点,而且后面都是这一个点。
  2. 相交的圆组成的等价圆。

对于点我们记录开始和结束的点与中间走过的距离,对于圆我们记录圆心与半径。

我们考虑合并:

  1. 两个圆合并
    • 若两圆相交,计算出等价圆
    • 若相切或相离,则分别计算出两个圆与连心线的交点,前者作为起点,后者作为终点,计算中间的距离。
  2. 圆与点的合并
    • 根据情况合并开始结点或结束结点把点可以看做半径为 0 的圆。
  3. 点与点的合并
    • 将上一个点的结束结点与下一个点的开始结点距离算出,求和即可。

计算连心线与圆的交点可以使用树上 k 级祖先的做法,时间复杂度为 O(\log n)

那么就做完了,时间复杂度是 O(q\log m\log n)

因为是 \log^2 的做法,所以需要注意精细实现,说几个要点:

  1. 使用 O(1) lca 会显著减少常数。
  2. 尽可能利用已经求出的 lca
  3. 优化内存访问。

code

代码不长,只有不到 4k,在 QOJ 上荣获目前最短解。 :::success[在QOJ上通过]

#include <bits/stdc++.h>
#include "gems.h"
#define _F(x,y,z) for(int x=y;x<=z;x++)
#define F_(x,y,z) for(int x=y;x>=z;x--)
#define TF(x,y,z) for(int x=head[y],z;x;x=nex[x])

using namespace std;

typedef long long ll;
typedef const int ci;
typedef pair<int,int> pii;
typedef double dou;

ci maxn=6e5+10,p=1e9+7;
int n,m,fa[22][maxn],dep[maxn],st[22][maxn],dfn[maxn],di[maxn],cnt,lg[maxn];
vector<int> e[maxn],a,d;
inline void dfs(int x,int f)
{
    fa[0][x]=f;dep[x]=dep[f]+1;
    dfn[x]=++cnt,di[cnt]=x;
    st[0][cnt]=x;
    _F(i,1,21)
        fa[i][x]=fa[i-1][fa[i-1][x]];
    for(int y:e[x])
    {
        if(y!=f)
            dfs(y,x);
    }
}
inline void init()
{
    lg[0]=-1;
    _F(i,1,maxn-10)
        lg[i]=lg[i>>1]+1;
    for(int i=1;(1<<i)<=n;i++)
    {
        for(int j=1;j+(1<<(i))-1<=n;j++)
        {
            int x=st[i-1][j],y=st[i-1][j+(1<<(i-1))];
            if(dep[x]>dep[y])
                st[i][j]=y;
            else
                st[i][j]=x;
        }
    }
}
inline int lca(int x,int y)
{
    if(x==y) return x;
    x=dfn[x],y=dfn[y];
    if(x>y)swap(x,y);
    x++;
    int len=lg[y-x+1],xx=st[len][x],yy=st[len][y-(1<<len)+1];
    if(dep[xx]>dep[yy])
        return fa[0][yy];
    return fa[0][xx];
}
inline int dis(int x,int y,int l=-1)
{
    l=(l<0)?lca(x,y):l;
    return dep[x]+dep[y]-2*dep[l];
}
inline int getdis(int x,int y,int d,int l=-1)
{
    l=(l<0)?lca(x,y):l;
    if(d>dep[x]-dep[l])
        return getdis(y,x,dis(x,y)-d,l);
    while(d)
        x=fa[lg[(d&-d)]][x],d-=d&-d;
    return x;
}
#define lson (now<<1)
#define rson (now<<1|1)
#define mid ((l+r)>>1)
struct dat 
{
    ll dis;
    int tp,c1,r1,st,ed;
}tr[maxn<<2];
dat operator +(dat x,dat y)
{
    if(x.tp&&y.tp)
        return (dat){x.dis+y.dis+dis(x.ed,y.st),1,0,0,x.st,y.ed};
    int cx=x.tp?x.ed:x.c1,cy=y.tp?y.st:y.c1,rx=x.r1,ry=y.r1;
    int l=lca(cx,cy),ds=dis(cx,cy,l);
    if(!(x.tp||y.tp))
    {
        if(ds>rx+ry)
        {
            int st=getdis(cx,cy,rx,l),ed=getdis(cy,cx,ry,l);
            return (dat){dis(st,ed),1,0,0,st,ed};
        }
        int lll=max(-rx,ds-ry),rrr=min(rx,ds+ry);
        int md=(lll+rrr)>>1;
        int nc=getdis(cx,cy,md,l);
        return (dat){0,0,nc,md-lll,0,0};
    }
    if(x.tp)
    {
        if(ds>ry)
        {
            int ed=getdis(cy,cx,ry,l);
            return (dat){x.dis+dis(cx,ed),1,0,0,x.st,ed};
        }
        return (dat){x.dis,1,0,0,x.st,x.ed};
    }
    if(y.tp)
    {
        if(ds>rx)
        {
            int st=getdis(cx,cy,rx,l);
            return (dat){dis(st,cy)+y.dis,1,0,0,st,y.ed};
        }
        return (dat){y.dis,1,0,0,y.st,y.ed};
    }
}
void build(int now,int l,int r)
{
    if(l==r)
    {
        tr[now]=(dat){0,0,a[l-1],2*d[l-1],0,0};
        return ;
    }
    build(lson,l,mid);
    build(rson,mid+1,r);
    tr[now]=tr[lson]+tr[rson];
}
void ask(int now,int l,int r,int x,int y,dat &ans)
{   
    if(x<=l&&r<=y)
    {
        ans=ans+tr[now];
        return ;
    }
    if(x<=mid)
        ask(lson,l,mid,x,y,ans);
    if(y>mid)   
        ask(rson,mid+1,r,x,y,ans);
}
void gems(int c, int _n, int _m, std::vector<int> u, std::vector<int> v, std::vector<int> _a, std::vector<int> _d) 
{
    n=_n,m=_m;
    _F(i,0,_n-2)
    {
        int x=u[i],y=v[i];
        e[x].push_back(i+n+1);
        e[y].push_back(i+n+1);
        e[i+n+1].push_back(x);
        e[i+n+1].push_back(y);
    }
    n=2*_n-1;
    dfs(1,0);
    init();
    a=_a,d=_d;
    build(1,1,m);
    return;
}

long long query(int x, int l, int r) {
    dat ans=(dat){0,1,0,0,x,x};
    ask(1,1,m,l,r,ans);
    return ans.dis/2;
}

::: 图糙轻喷