题解:P16435 [APIO 2026 中国赛区] 集宝

· · 题解

cz 老师下次不要再不关同步了呜呜呜。

定义 S(a,d) 表示距离点 a 不超过 d 的点集。

对于一次询问 (x,l,r),我们从 x 开始先走到 x_l \in S(a_l,d_l),下一步从 x_l 走到 x_{l+1} \in S(a_{l+1},d_{l+1})。可以发现,如果 S(a_l,d_l) \cap S(a_{l+1},d_{l+1}) \ne \varnothing,那么一定有 x_{l+1} \in S(a_l,d_l) \cap S(a_{l+1},d_{l+1})

而如果 S(a_l,d_l) \cap S(a_{l+1},d_{l+1}) = \varnothing,那么 x_{l+1} 的值就是确定的,和 x 无关。

也就是说,我们可以先找到 p 使得 \cap_{i=l}^p S(a_i,d_i) \ne \varnothing \wedge \cap_{i=l}^{p+1} S(a_i,d_i) = \varnothing,这样走到 S(a_{p+1},d_{p+1}) 之后每一步都和初始的 x 无关。

区间询问考虑分解到线段树上,线段树上每个区间 [l,r] 求出上述的 ps = \cap_{i=l}^p S(a_i,d_i)。那么从 x 经过 [l,r] 时只需从 x 走到 s,之后的贡献可以预处理计算。

现在考虑邻域求交,有结论是树上点邻域的交是点邻域或边邻域(即中心可能在点上或边上),在边上插入一点即可统一为点邻域方便计算。有结论后容易发现我们需要求 LCA 和 LA。前者使用 O(n \log n)-O(1) 的 dfn 序 st 表,后者使用 O(n)-O(\log n) 的同层 dfn 序二分,单次求交复杂度 O(\log n)

具体细节和做法详见代码,这里为方便总复杂度为 O((m+q) \log n \log m)。可以使用 O(n \log n)-O(1) 的长剖做到 O(n \log n + (m+q)\log m),但是常数更大跑得慢。精细实现也可以做到 O(m + q \log m)/O(m \log m + q)

卡常 | 长剖。复现的考场代码如下(QOJ 可过,洛谷 TLE):

::::info[code]

#include "gems.h"
#include <bits/stdc++.h>
using namespace std;

using LL = long long;
struct CIR {int a, d;};
constexpr int N = 6e5 + 5, M = 3e5 + 5, L = 20;

int m;
vector <int> gra[N];
CIR gem[M];
int dfn[N], dep[N], spt[L][N];
vector <int> lay[N];
struct DAT
{
    int pos;
    CIR pre;
    int las;
    LL sum;
}
dat[M << 2];

int cmp(int x, int y)
{
    return dep[x] < dep[y] ? x : y;
}

int lca(int x, int y)
{
    if (x == y) return x;
    x = dfn[x], y = dfn[y];
    if (x > y) swap(x, y);
    int k = 31 ^ __builtin_clz(y - x);
    return cmp(spt[k][y], spt[k][x + (1 << k)]);
}

int kth(int x, int k)
{
    k = dep[x] - k, x = dfn[x];
    int l = 0, r = lay[k].size() - 1, pos = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (dfn[lay[k][mid]] <= x) l = (pos = mid) + 1;
        else r = mid - 1;
    }
    return lay[k][pos];
}

CIR operator+(const CIR &x, const CIR &y)
{
    int p = lca(x.a, y.a), d = dep[x.a] + dep[y.a] - dep[p] * 2;
    if (x.d + y.d < d) return {0, 0};
    if (x.d + d <= y.d) return x;
    if (y.d + d <= x.d) return y;
    int dd = (x.d + y.d - d) / 2;
    if (dep[x.a] - dep[p] >= x.d - dd) return {kth(x.a, x.d - dd), dd};
    return {kth(y.a, y.d - dd), dd};
}

CIR operator+(int x, const CIR &y)
{
    int p = lca(x, y.a), d = dep[x] + dep[y.a] - dep[p] * 2;
    if (d <= y.d) return {x, 0};
    if (dep[y.a] - dep[p] >= y.d) return {kth(y.a, y.d), d - y.d};
    return {kth(x, d - y.d), d - y.d};
}

int calc(int x, const CIR &y)
{
    return max(dep[x] + dep[y.a] - dep[lca(x, y.a)] * 2 - y.d, 0);
}

void dfs(int u, int f)
{
    dfn[u] = ++dfn[0], dep[u] = dep[f] + 1;
    spt[0][dfn[u]] = f, lay[dep[u]].push_back(u);
    for (int v : gra[u])
        if (v != f) dfs(v, u);
}

#define lp (p << 1)
#define rp (p << 1 | 1)
#define c ((l + r) >> 1)

void build(int p = 1, int l = 1, int r = m + 1)
{
    auto &[pos, pre, las, sum] = dat[p];
    pos = l + 1, pre = gem[l];
    while (pos < r)
    {
        CIR nxt = pre + gem[pos];
        if (!nxt.a) break;
        pre = nxt, ++pos;
    }
    if (pos < r)
    {
        las = (pre.a + gem[pos]).a;
        for (int i = pos + 1; i < r; i ++)
        {
            CIR nxt = las + gem[i];
            las = nxt.a, sum += nxt.d;
        }
    }
    if (r - l > 1) build(lp, l, c), build(rp, c, r);
}

LL divide(int &x, int L, int R, int p = 1, int l = 1, int r = m + 1)
{
    if (r <= L || l >= R) return 0;
    if (l >= L && r <= R)
    {
        CIR nxt = x + dat[p].pre;
        LL sum = nxt.d; x = nxt.a;
        if (dat[p].pos < r)
        {
            sum += calc(x, gem[dat[p].pos]) + dat[p].sum;
            x = dat[p].las;
        }
        return sum;
    }
    return divide(x, L, R, lp, l, c) + divide(x, L, R, rp, c, r);
}

#undef lp
#undef rp
#undef c

void gems(int c, int n, int _m, vector <int> _u, vector <int> _v, vector <int> _a, vector <int> _d)
{
    m = _m;
    for (int i = 0; i < n - 1; i ++)
    {
        gra[_u[i]].push_back(n + i + 1);
        gra[_v[i]].push_back(n + i + 1);
        gra[n + i + 1] = {_u[i], _v[i]};
    }
    for (int i = 0; i < m; i ++) gem[i + 1] = {_a[i], _d[i] * 2};

    dfs(1, 0);
    for (int i = 1; (1 << i) <= n * 2; i ++)
        for (int j = 1 << i; j <= n * 2; j ++)
            spt[i][j] = cmp(spt[i - 1][j], spt[i - 1][j - (1 << (i - 1))]);
    build();
}

LL query(int x, int l, int r)
{
    return divide(x, l, r + 1) / 2;
}

::::