题解:P16435 [APIO 2026 中国赛区] 集宝

· · 题解

优秀的 DS。

n,m,q \le 10^3

会思考一个贪心,就是每次直接向下一个圆走,走到圆上就停止,然后思考下一个圆。这个贪心是正确的,要用到 \operatorname{LCA} 和树上 k 级祖先。

::::warning[证明] 显然如果走到圆上后再往其他位置走是不优的,可以放到下一步移动再走。 ::::

n,m \le 10^3 , q \le 3 \times 10^5

把求 和树上 k 级祖先改成 O(1) 的即可。前者使用欧拉序,后者用长剖

如果只想要这一部分分的话,也可以先预处理任意两个点的 \operatorname{LCA} 和任意一个点的任意祖先。

特殊性质 A

因为是一条链,所以树上的圆就是区间,直接上线段树即可。

特殊性质 B

要一点注意力,我们会发现,树上所有圆一定会在重心处有交。现在,我们证明一个定理,树上任意两个圆,如果有交,则交为一个圆。

::::warning[证明]

我们记一个圆为 (u,r) 表示所有 dis(u,v) \le r 的点集。

d = \operatorname{dis}(u_1, u_2),R = \frac{r_1+r_2-d}{2}。 取 u_1,u_2 路径上距离 u_1r_1-R 的点 A,则圆 (u_1,r_1),(u_2,r_2) 的交为 (A,R)

证明 (u_1,r_1)\cap(u_2,r_2) \subseteq (A,R)

对于任意一个点 v,满足 \operatorname{dis}(u_1,v) \le r_1,\operatorname{dis}(u_2,v) \le r_2。此时 \operatorname{dis}(v,A) = \max(\operatorname{dis}(v,u_1)+R-r_1,\operatorname{dis}(v,u_2)+R-r_2),发现前后两个式子都 \le R,所以 \operatorname{dis}(v,A) \le R,即 v \in (A,R)

证明 (A,R) \subseteq (u_1,r_1)\cap(u_2,r_2)

对于任意一个 v,满足 \operatorname{dis}(v,A) \le R,首先我们会注意到 \operatorname{dis}(v,u_1) \le \operatorname{dis}(v,A)+\operatorname{dis}(A,u_1) = \operatorname{dis}(v,A) + r_1-R \le R + r_1 - R = r_1,所以 \operatorname{dis}(v,u_1) \le r_1(u_2,r_2) 同理。 ::::

同时我们发现,如果要从点 x,依次经过若干个圆,最后一定会停在这些圆的交上,因为如果最后不在交上,说明一定离开过某个圆,这是不优的。所以我们可以直接用线段树对圆求交,然后直接判断点到圆的距离即可。

特殊性质 C

现在我们知道了如果所有圆有交,那么一定会直接走到交上。现在我们来思考如果某些圆没有交。

我们一定可以先找到一个最大的 p,满足前 p 个圆是有交的,然后走到这些圆的交上。接下来我们会发现,从一个圆走到另一个没有交的圆上,在圆外的路径是唯一的,所以我们可以预处理出来从第 p+1 个圆开始的路径,然后记录。由于 l = 1,我们直接从左往右模拟即可。

n,m,q \le 3\times10^5

最后的一部分。我们现在思考怎么求 x \to [l,r] 的最短路。找到最大的 p 满足 [l,p] 的圆是有交的,先从 x 走到圆上,然后剩下的路径就是确定的。于是我们会想到用线段树,首先把区间拆成 \log 个,然后对于一个区间,我们预处理出这个区间的 p,前缀的圆的交,后续的路径长,和最后走到的位置。这些可以在 build 的时候 O(m\log m) 预处理,查询的时候模拟即可。时间复杂度就是 O(n+\log n + (m+q) \log m),其中的 O(n\log n) 是预处理欧拉序和 k 级祖先用的。

::::info[Code(QOJ 可过)]

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
const int N = 600005;
int n, m, dep[N], mxd[N], fa[N][25], son[N], top[N], lg[N << 1], E[N << 1][25], dfn[N], ex;
int *f[N], *g[N], buf[N << 1], *now;
vector <int> e[N];
#define get(u, v) ( dep[u] < dep[v] ? u : v )
void dfs1(int u){
    mxd[u] = dep[u] = dep[fa[u][0]] + 1, son[u] = 0, E[++ex][0] = u, dfn[u] = ex;
    for (int v : e[u]) if (v != fa[u][0])
        fa[v][0] = u, dfs1(v), mxd[u] = max(mxd[u], mxd[v]), son[u] = (mxd[son[u]] > mxd[v] ? son[u] : v), E[++ex][0] = u;
}
void dfs2(int u, int tp){
    top[u] = tp;
    if (u == tp){
        f[u] = now, now += mxd[u] - dep[u] + 1, g[u] = now, now += mxd[u] - dep[u] + 1;
        for (int i = 0, x = u, y = u; i <= mxd[u] - dep[u]; i++, x = fa[x][0], y = son[y]) f[u][i] = x, g[u][i] = y;
    }
    if (son[u]) dfs2(son[u], tp);
    for (int v : e[u]) if (v != son[u] && v != fa[u][0]) dfs2(v, v);
}
il int kth(int u, int k){
    if (!k) return u;
    u = fa[u][lg[k]], k -= (1 << lg[k]), k -= dep[u] - dep[top[u]], u = top[u];
    return k >= 0 ? f[u][k] : g[u][-k];
}
il int Lca(int u, int v){
    int l = dfn[u], r = dfn[v]; if (l > r) swap(l, r);
    int p = lg[r - l + 1]; return get(E[l][p], E[r - (1 << p) + 1][p]);
}
#define dis(u, v) ( dep[u] + dep[v] - (dep[Lca(u, v)] << 1) )
struct Node{ int u, r; }a[N];
struct Tree{ int p, u; ll sum; Node A; }tr[N << 2];
Node operator +(const Node &a, const Node&b){
    int L = Lca(a.u, b.u), d = dep[a.u] + dep[b.u] - (dep[L] << 1);
    if (a.r + b.r < d) return Node{0, 0};
    if (a.r + d <= b.r) return a;
    if (b.r + d <= a.r) return b;
    int R = (a.r + b.r - d) >> 1;
    if (dep[a.u] - dep[L] >= a.r - R) return Node{kth(a.u, a.r - R), R};
    return Node{kth(b.u, b.r - R), R};
}
Node operator +(int u, const Node&a){
    int L = Lca(u, a.u), d = dep[u] + dep[a.u] - (dep[L] << 1);
    if (d <= a.r) return Node{u, 0};
    if (dep[a.u] - dep[L] >= a.r) return Node{kth(a.u, a.r), d - a.r};
    return Node{kth(u, d - a.r), d - a.r};
}
void build(int p, int l, int r){
    tr[p].p = l + 1, tr[p].A = a[l];
    while (tr[p].p <= r){
        Node C = tr[p].A + a[tr[p].p];
        if (!C.u) break; else tr[p].A = C, tr[p].p++;
    }
    if (tr[p].p <= r){
        tr[p].u = (tr[p].A.u + a[tr[p].p]).u;
        for (int i = tr[p].p + 1; i <= r; i++){
            Node C = tr[p].u + a[i];
            tr[p].u = C.u, tr[p].sum += C.r;
        }
    }
    if (l == r) return ;
    int mid = (l + r) >> 1;
    build(p << 1, l, mid), build(p << 1 | 1, mid + 1, r);
}
ll ask(int p, int l, int r, int nl, int nr, int &x){
    ll ans = 0;
    if (nl <= l && r <= nr){
        Node C = x + tr[p].A;
        x = C.u, ans += C.r;
        if (tr[p].p <= r) ans += max(dis(x, a[tr[p].p].u) - a[tr[p].p].r, 0) + tr[p].sum, x = tr[p].u;
        return ans;
    }
    int mid = (l + r) >> 1;
    if (nl <= mid) ans += ask(p << 1, l, mid, nl, nr, x);
    if (nr > mid) ans += ask(p << 1 | 1, mid + 1, r, nl, nr, x);
    return ans;
}
void gems(int cid, int n, int m, vector<int> _u, vector<int> _v, vector<int> _a, vector<int> _d){
    ::n = n, ::m = m, lg[0] = -1, now = buf;
    for (int i = 1; i < n * 4; i++) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i <= m; i++) a[i] = Node{_a[i - 1], _d[i - 1] * 2};
    for (int i = 0; i < n - 1; i++) e[_u[i]].push_back(n + i + 1), e[_v[i]].push_back(n + i + 1), e[n + i + 1] = {_u[i], _v[i]};
    n += n - 1;
    dfs1(1), dfs2(1, 1);
    for (int j = 1; j <= 21; j++) for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
    for (int j = 1; j <= 21; j++) for (int i = 1; i + (1 << j) - 1 <= ex; i++) E[i][j] = get(E[i][j - 1], E[i + (1 << (j - 1))][j - 1]);
    build(1, 1, m);
}

ll query(int x, int l, int r){ return ask(1, 1, m, l, r, x) / 2; }

::::

但是在洛谷过不了,我们来优化一下。同样的我们注意到,我们存储前缀圆的交,仅仅只是确定 x 往哪里走。但是实际上我们会发现,我们连一开始的圆都不用存,可以直接记录下面这个紫色点。

我们完全可以视为我们直接走到紫色点,然后走后面的一段路径。也就是说,一段区间合并之后,要么是一个圆,要么是一条路径。这样子就避免了存储前缀圆的交。

然后实现上,树剖求 \operatorname{LCA} 比欧拉序快,倍增求树上 k 级祖先比长剖快,同时线段树可以改成 ST 表。这样子写下来只有 2.6 K,还跑得飞快。

::::info[Code(QOJ 和洛谷均可过)]

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
il int rd(){
    int s = 0, w = 1;
    char ch = getchar();
    for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
    for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
    return s * w;
}
const int N = 600005, M = 300005;
int n, m, dep[N], sz[N], fa[N][25], son[N], top[N], dfn[N], ex;
vector <int> e[N];
void dfs1(int u){
    dep[u] = dep[fa[u][0]] + (sz[u] = 1);
    for (int v : e[u]) if (v != fa[u][0]){
        fa[v][0] = u;
        for (int j = 0; fa[v][j]; j++) fa[v][j + 1] = fa[fa[v][j]][j];
        dfs1(v), sz[u] += sz[v], son[u] = (sz[son[u]] > sz[v] ? son[u] : v);
    }
}
void dfs2(int u, int tp){
    top[u] = tp, dfn[u] = ++ex;
    if (son[u]) dfs2(son[u], tp);
    for (int v : e[u]) if (v != son[u] && v != fa[u][0]) dfs2(v, v);
}
il int To(int u, int v, int L, int d){
    if (dep[u] - dep[L] < d) d = dep[u] + dep[v] - (dep[L] << 1) - d, u = v;
    for (int j; d;) j = __builtin_ctz(d), d -= (1 << j), u = fa[u][j];
    return u;
}
il int Lca(int u, int v){
    for (; top[u] != top[v]; u = fa[top[u]][0]) if (dep[top[u]] < dep[top[v]]) swap(u, v);
    return dep[u] < dep[v] ? u : v;
}
struct Node{ int u, v, r; ll sum; }F[M][21];
il Node operator +(const Node &a, const Node &b){
    Node res = Node{0, 0, 0, 0}; res.sum = a.sum + b.sum;
    int L = Lca(a.v, b.u), d = dep[a.v] + dep[b.u] - (dep[L] << 1);
    if (d >= a.r + b.r){
        res.sum += d - a.r - b.r;
        if (!a.sum) res.u = To(a.v, b.u, L, a.r);
        else res.u = a.u;
        if (!b.sum) res.v = To(b.u, a.v, L, b.r);
        else res.v = b.v;
    }
    else if (a.r + d <= b.r) res.u = a.u, res.v = a.v, res.r = a.r;
    else if (b.r + d <= a.r) res.u = b.u, res.v = b.v, res.r = b.r;
    else res.r = (a.r + b.r - d) >> 1, res.u = res.v = To(a.u, b.u, L, a.r - res.r);
    return res;
}
void gems(int cid, int n, int m, vector<int> _u, vector<int> _v, vector<int> _a, vector<int> _d){
    ::n = n, ::m = m;
    for (int i = 0; i < n - 1; i++) e[_u[i]].push_back(n + i + 1), e[_v[i]].push_back(n + i + 1), e[n + i + 1] = {_u[i], _v[i]};
    for (int i = 1; i <= m; i++) F[i][0] = Node{_a[i - 1], _a[i - 1], _d[i - 1] * 2, 0};
    n += n - 1, dfs1(1), dfs2(1, 1);
    for (int j = 1; j <= 21; j++) for (int i = 1; i + (1 << j) - 1 <= m; i++) F[i][j] = F[i][j - 1] + F[i + (1 << (j - 1))][j - 1];
}
ll query(int x, int l, int r){
    Node res = Node{x, x, 0, 0};
    for (int j = 21; ~j; j--) if (l + (1 << j) - 1 <= r) res = res + F[l][j], l += (1 << j);
    return res.sum >> 1;
}

::::