题解:P16435 [APIO 2026 中国赛区] 集宝
前言
提供 APIO 讲题人的做法。
思路
首先我们发现,若
那我们可以对
根据树上圆理论,在原树上的每一条边插入一个虚点,设
::::info[如何在
首先特判掉交集为空(用
单次询问可以做到
可以参考下图理解。
::::
缩完后设这些邻域为
把路径简化为
使用双指针,对于
求出
询问时先把答案初始化为
使用 ST 表求区间的交集,时间复杂度
代码
#include<bits/stdc++.h>
#define ll long long
#include "grader.cpp"
using namespace std;
const int N = 6e5+5;
int n,m,a[N],d[N];
vector<int> g[N];
int pre[N],dfn[N],idx,sz[N],son[N],dep[N],top[N],f[N];
void dfs1(int u,int fa)
{
sz[u] = 1,dep[u] = dep[fa]+1,f[u] = fa;
for(auto v:g[u])
{
if(v==fa) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u] = v;
}
}
void dfs2(int u,int tp)
{
pre[dfn[u] = ++idx] = u,top[u] = tp;
if(!son[u]) return;
dfs2(son[u],tp);
for(auto v:g[u])
{
if(v==son[u]||v==f[u]) continue;
dfs2(v,v);
}
}
inline int lca(int x,int y)
{
while(top[x]^top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x = f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
inline int jump(int x,int k)
{
if(dep[x]<k) return 0;
while(k>dep[x]-dep[top[x]])
{
k-=dep[x]-dep[top[x]]+1;
x = f[top[x]];
}
return pre[dfn[x]-k];
}
inline int dis(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];}
inline pair<int,int> merge(pair<int,int> x,pair<int,int> y)
{
if(x.second<0||y.second<0) return {1,-1};
int d = dis(x.first,y.first);
int res = (x.second+y.second-d)/2;
if(res<0) return {1,-1};
if(x.second<=res) return x;
if(y.second<=res) return y;
int p1 = jump(x.first,x.second-res);
int p2 = jump(y.first,y.second-res);
if(dep[p1]>dep[p2]) return {p1,res};
else return {p2,res};
}
inline int work(pair<int,int> x,pair<int,int> y)
{
int l = lca(x.first,y.first);
if(dep[x.first]-dep[l]>=x.second) return jump(x.first,x.second);
else return jump(y.first,dep[y.first]+dep[x.first]-2*dep[l]-x.second);
}
pair<int,int> st[20][N],tmp[N];
inline pair<int,int> get(int l,int r)
{
int lg = __lg(r-l+1);
return merge(st[lg][l],st[lg][r-(1<<lg)+1]);
}
int rt[N],pos[N],to[20][N];
ll sum[20][N];
void gems(int c,int n_,int m_,vector<int> u,vector<int> v,vector<int> a_,vector<int> d_)
{
n = n_,m = m_;
for(int i = 1,x,y;i<n;i++)
{
x = u[i-1],y = v[i-1];
g[x].push_back(i+n);
g[y].push_back(i+n);
g[i+n].push_back(x);
g[i+n].push_back(y);
}
for(int i = 1;i<=m;i++)
a[i] = a_[i-1],d[i] = d_[i-1];
dfs1(1,0),dfs2(1,1);
for(int i = 1;i<=m;i++)
st[0][i] = {a[i],d[i]*2};
for(int j = 1;j<20;j++)
for(int i = 1;i+(1<<j)-1<=m;i++)
st[j][i] = merge(st[j-1][i],st[j-1][i+(1<<j-1)]);
int p = 1;
for(int i = 1;i<=m;i++)
{
p = max(p,i);
while(p<=m&&get(i,p).second!=-1) p++;
rt[i] = p,tmp[i] = get(i,p-1);
}
for(int i = 1;i<=m;i++)
{
if(rt[i]<=m) pos[i] = work(tmp[i],tmp[rt[i]]);
else pos[i] = n;
}
rt[m+1] = m+1,pos[m+1] = n;
for(int i = 1;i<=m+1;i++)
to[0][i] = rt[i],sum[0][i] = dis(pos[i],pos[rt[i]]);
for(int j = 1;j<20;j++)
for(int i = 1;i<=m+1;i++)
to[j][i] = to[j-1][to[j-1][i]],sum[j][i] = sum[j-1][i]+sum[j-1][to[j-1][i]];
}
ll query(int x,int l,int r)
{
if(rt[l]>r)
{
auto tmp = get(l,r);
if(dis(x,tmp.first)<=tmp.second) return 0;
return dis(x,work(get(l,r),{x,0}))/2;
}
ll ans = dis(x,pos[l]);x = pos[l];
for(int i = 19;~i;i--)
if(rt[to[i][l]]<=r)
ans+=sum[i][l],x = pos[to[i][l]],l = to[i][l];
ans+=dis(x,work(get(rt[l],r),{x,0}));
return ans/2;
}