题解:P16435 [APIO 2026 中国赛区] 集宝

· · 题解

前言

提供 APIO 讲题人的做法。

思路

首先我们发现,若 l\sim r 这些领域之交非空,我们发现直接将 x 移到交集内一定不劣。即跳到这个交集中距离 x 最近的点,如何求出这个点见下文。

那我们可以对 l\sim r 进行缩点,把相邻的有交的两个领域缩成一个他们的交集。

根据树上圆理论,在原树上的每一条边插入一个虚点,设 (x,y) 表示以 x 为中点,y 为半径的邻域,则 (x_1,y_1)\cap(x_2,y_2) 一定能表示成 (x_3,y_3)

::::info[如何在 O(\log n) 时间复杂度下求出 (x_3,y_3)]

首先特判掉交集为空(用 y<0 来表示此邻域为空集),以及两个邻域有互相包含的关系的两种情况。此时 y_3 一定是 \frac{y_1+y_2-\operatorname{dis}(x_1,x_2)}{2},而 x_3 就是 x_1\to x_2 这条路径上第 y_1-y_3 个点,这个用树上 k 级祖先求即可。

单次询问可以做到 O(\log n)

可以参考下图理解。

::::

缩完后设这些邻域为 (x_1,y_1),(x_2,y_2),\dots,(x_k,y_k)。考虑从 (x_i,y_i) 中的某个点跳入 (x_{i+1},y_{i+1}),发现我一定会经过 (x_i,y_i) 中距离 (x_{i+1},y_{i+1}) 最近的点 p_i。而 p_i 一定在 x_i,x_{i+1} 之间的路径上且距离在第一个领域的边界上,所以就是 x_1\to x_2 上的第 y_1 个点。

把路径简化为 x\to p_1\to p_2 \to \dots \to p_{k-1}\to pos,其中 pos(x_k,y_k) 中距离 p_{k-1} 中最近的点,把 p_{k-1} 视作一个半径为 0 的邻域,类似上面做法即可求出。

使用双指针,对于 x 先求出一个最大的 rt_y 表示 x\sim rt_y-1 的交集非空。显然区间 [l,r] 缩的段肯定是跳若干次 rt 构成的。则除去 x\to p_1p_{k-1}\to pos 的路径是可以通过 p_1r 确定的。

求出 p'_x 表示 x\sim rt_{x}-1 的交集到 rt_{x}\sim rt_{rt_x}-1 的交集最近的点。从 x 跳到 rt_x 的代价即为 \operatorname{dis}(p'_x,p'_{rt_x})

询问时先把答案初始化为 \operatorname{dis}(x,p'_l),然后从 l 开始跳一直跳到包含 r 那一段的上一段。倍增维护就行了。由于增加了虚点所以记得将答案除以二。

使用 ST 表求区间的交集,时间复杂度 O(n\log ^2 n+q\log n),可以通过。

代码

#include<bits/stdc++.h>
#define ll long long
#include "grader.cpp"
using namespace std;
const int N = 6e5+5;
int n,m,a[N],d[N];
vector<int> g[N];
int pre[N],dfn[N],idx,sz[N],son[N],dep[N],top[N],f[N];
void dfs1(int u,int fa)
{
    sz[u] = 1,dep[u] = dep[fa]+1,f[u] = fa;
    for(auto v:g[u])
    {
        if(v==fa) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u] = v;
    }
}
void dfs2(int u,int tp)
{
    pre[dfn[u] = ++idx] = u,top[u] = tp;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(auto v:g[u])
    {
        if(v==son[u]||v==f[u]) continue;
        dfs2(v,v);
    }
}
inline int lca(int x,int y)
{
    while(top[x]^top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x = f[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}
inline int jump(int x,int k)
{
    if(dep[x]<k) return 0;
    while(k>dep[x]-dep[top[x]])
    {
        k-=dep[x]-dep[top[x]]+1;
        x = f[top[x]];
    }
    return pre[dfn[x]-k];
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
inline pair<int,int> merge(pair<int,int> x,pair<int,int> y)
{
    if(x.second<0||y.second<0) return {1,-1};
    int d = dis(x.first,y.first);
    int res = (x.second+y.second-d)/2;
    if(res<0) return {1,-1};
    if(x.second<=res) return x;
    if(y.second<=res) return y;
    int p1 = jump(x.first,x.second-res); 
    int p2 = jump(y.first,y.second-res); 
    if(dep[p1]>dep[p2]) return {p1,res};
    else return {p2,res};
}
inline int work(pair<int,int> x,pair<int,int> y)
{
    int l = lca(x.first,y.first);
    if(dep[x.first]-dep[l]>=x.second) return jump(x.first,x.second);
    else return jump(y.first,dep[y.first]+dep[x.first]-2*dep[l]-x.second);
}
pair<int,int> st[20][N],tmp[N];
inline pair<int,int> get(int l,int r)
{
    int lg = __lg(r-l+1);
    return merge(st[lg][l],st[lg][r-(1<<lg)+1]);
}
int rt[N],pos[N],to[20][N];
ll sum[20][N];
void gems(int c,int n_,int m_,vector<int> u,vector<int> v,vector<int> a_,vector<int> d_)
{
    n = n_,m = m_;
    for(int i = 1,x,y;i<n;i++)
    {
        x = u[i-1],y = v[i-1];
        g[x].push_back(i+n);
        g[y].push_back(i+n);
        g[i+n].push_back(x);
        g[i+n].push_back(y);
    }
    for(int i = 1;i<=m;i++)
        a[i] = a_[i-1],d[i] = d_[i-1];
    dfs1(1,0),dfs2(1,1);
    for(int i = 1;i<=m;i++)
        st[0][i] = {a[i],d[i]*2};
    for(int j = 1;j<20;j++)
        for(int i = 1;i+(1<<j)-1<=m;i++)
            st[j][i] = merge(st[j-1][i],st[j-1][i+(1<<j-1)]);
    int p = 1;
    for(int i = 1;i<=m;i++)
    {
        p = max(p,i);
        while(p<=m&&get(i,p).second!=-1) p++;
        rt[i] = p,tmp[i] = get(i,p-1);
    }
    for(int i = 1;i<=m;i++)
    {
        if(rt[i]<=m) pos[i] = work(tmp[i],tmp[rt[i]]);
        else pos[i] = n;
    }
    rt[m+1] = m+1,pos[m+1] = n;
    for(int i = 1;i<=m+1;i++)
        to[0][i] = rt[i],sum[0][i] = dis(pos[i],pos[rt[i]]);
    for(int j = 1;j<20;j++)
        for(int i = 1;i<=m+1;i++)
            to[j][i] = to[j-1][to[j-1][i]],sum[j][i] = sum[j-1][i]+sum[j-1][to[j-1][i]];
}
ll query(int x,int l,int r)
{
    if(rt[l]>r)
    {
        auto tmp = get(l,r);
        if(dis(x,tmp.first)<=tmp.second) return 0; 
        return dis(x,work(get(l,r),{x,0}))/2;
    }
    ll ans = dis(x,pos[l]);x = pos[l];
    for(int i = 19;~i;i--)
        if(rt[to[i][l]]<=r)
            ans+=sum[i][l],x = pos[to[i][l]],l = to[i][l];
    ans+=dis(x,work(get(rt[l],r),{x,0}));
    return ans/2;
}