P14312 【模板】K-D Tree の题解

· · 题解

kdt 略解

简介:本文主要介绍 kdt 的结构体封装和替罪羊维护写法。

其实我都不会替罪羊,只会替罪羊维护的 kdt,同时更新了复杂度证明。

题意

本题需要支持的操作:

  1. 插点;
  2. 矩形(立方体)加;
  3. 矩形(立方体)求和。

然后卡空间,时限很松。

这不就是 kdt 板题吗。

算法介绍

应用

可以高效支持在线高维矩形查询问题。

也可以通过剪枝解决一些平面点对问题。

原理

以下只讨论二维情况,三维请自行类比。

kdt 具有二叉搜索树的形态,二叉搜索树人话就是中序遍历是有序的带权二叉树。

既然中序遍历是有序的可以搜索,我们可以通过每次选中位数做根,建出一棵树,例如:

before: (0,4) (2,0) (1,3) (2,4) (1,4) (2,2) (0,0)
sorted: (0,0) (0,4) (1,3) (1,4) (2,0) (2,2) (2,4)

          (1,4)
        /       \
     (0,4)     (2,2)
    /    \     /    \
 (0,0) (1,3) (2,0) (2,4)

如果直接这样建,建完观察 (0,4) 所在子树,不像 (2,2) 的子树某一维有序,这样不方便查询矩形,所以每次我们选择某一维的中位数,切割成两个子矩形,方便查找以及查找时的剪枝。

虽然这样操作后就不是二叉搜索树了,但是具有二叉搜索树的形态(大概就是这个意思),所以我们的矩形查询能被切割成两个子矩形,同时为了保证复杂度,每一维轮流选中位数。

before:
          (1,4)
        /       \
     (0,4)     (2,2)
    /    \     /    \
 (0,0) (1,3) (2,0) (2,4)

now:
          (1,4)
        /       \           sorted by x
     (1,3)     (2,2)
    /    \     /    \       sorted by y
 (0,0) (0,4) (2,0) (2,4)

或者看图。

这样我们要查找矩形 [(0,2),(2,4)] 就可以这样递归搜索。

                   included-(1,3) ---+                +--- included-(2,2)
          (1,4)                      |  [(0,2),(2,4)] |
        /       \    divided by x=1 -+--/           \-+- included-(1,4)
     (1,3)     (2,2)           [(0,2),(1,4)]    [(1,2),(2,4)]
    /    \     /    \         /  by y=3  \        /  by y=2  \
 (0,0) (0,4) (2,0) (2,4)  (none) included-(0,4) (skip)  included-(2,4)

亦或看图。你看得出来在打架,我想尝试让你懂。

复杂度证明

我们查询一个矩形有如下三个流程:

  1. 如果查询与当前结点的区域无交集,直接跳出;
  2. 如果查询将当前结点的区域包含,直接跳出并上传答案;
  3. 有交集但不包含,继续递归求解。

明显的,复杂度来源于第三类查询。

因为上文说到按照轮流按照每一维进行划分,这样平面就会划分为若干个矩形。

假设我们查询一条竖线左侧的结点(结点代表一个矩形),那么按照竖线(也就是 x 这一维)划分的结点都可以剪掉一半的结点。

引用一张来自 Wallace 的图。

可以清楚看到,每两层可以减一次枝,每个结点又有两个儿子,也就是说每隔一层点数翻倍。

因此复杂度是:

\sum_{i=0}^{\frac{h}{2}}2^i\approx2^{\frac{h}{2}}\approx2^{\frac{\log_2n}{2}}=\sqrt n.

其中 h 是树高,通过替罪羊维护期望为 O(\log_2n),类似的,k 维时复杂度为 O(n^{1-\frac{1}{k}})

代码实现

建树

建树时需要求中位数,并且要把小于中位数的放一边,剩下的放另一边。

你可能会说直接提前 sort 不行吗,确实不可以直接 sort,因为不同深度会按不同维度排序,所以你可以手打一个带层数的快排,以支持不同深度会按不同维度排序。

更好的懒人方法可以在建树的时候直接用 nth_element(),自动把中位数排到正确位置,同时可以自定义比较函数。

复杂度:上述的建树过程本质就是带层数的快排,显然 O(n\log n)

查询

为了支持高维矩形查询,我们需要记录每一个矩形中每一维度上的坐标的最大值和最小值,递归查询时发现与这个矩形无交(上面的 (skip))就跳过,否则递归到分割后的矩形。

复杂度见复杂度证明。

插入

如何实现插入,直接看作一个矩形搜到对应空结点新建即可。

但是如果往一个地方插入过多的结点就炸了,所以需要重构。

复杂度见复杂度证明。

重构

具体的,如果一个结点的较大子树大小超过预先设定的比例,就炸掉重构。

重构很简单,暴力遍历把点拎出来再重新建树,期望复杂度类似替罪羊,但是会对查询复杂度有所影响。

小优化:改为根号重构或二进制分组,简单修改即可,这样就能保证树高了。

复杂度:单次重构 O(n\log n),其中 n 是子树大小,均摊复杂度 O(n\log n)

区间加

还是搜到矩阵,打个 tag 就行了,以后每次搜到一个结点就 pushdown

实现细节

K 是维数。

使用宏定义以减少实现难度。

#define tu t[u]
#define lu t[tu.ls]
#define ru t[tu.rs]
#define ALPHA 0.7

点是单独维护的,单开个结构体,和结点不同,结点要维护子树信息。

struct Pnt
{
    int p[K], val;
};

结点

注意构造时清空,方便实现。

最好写成默认构造。

struct nde
{
    int ls, rs, low[K], hig[K], siz, sum, tag;
    Pnt p;
    nde()
    {
        ls = rs = siz = sum = tag = 0;
        fill(low, low + K, INF), fill(hig, hig + K, -INF);
    }
} t[N];

默认清空,方便以下调用。

维护树

包括了清空、获取新结点、建树和重构。

注意每次清空后要重设上下界,而这一步通过 nde() 的构造函数实现。

int poo[N];
int _vec;
Pnt vec[N];
int newnde()
{
    if (*poo)
    {
        return poo[(*poo)--];
    }
    return ++tot;
}
void clear(int u)
{
    if (!u)
    {
        return;
    }
    pushdown(u);
    clear(tu.ls);
    vec[++_vec] = tu.p;
    clear(tu.rs);
    tu = nde();
    poo[++*poo] = u;
}
int build(int l, int r, int dep)
{
    if (l > r)
    {
        return 0;
    }
    int u = newnde(), mid = (l + r) >> 1;
    nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
                { return x.p[dep] < y.p[dep]; });
    tu.p = vec[mid];
    tu.ls = build(l, mid - 1, (dep + 1) % K);
    tu.rs = build(mid + 1, r, (dep + 1) % K);
    pushup(u);
    return u;
}
void check(int &u, int dep)
{
    if (tu.siz * ALPHA < max(lu.siz, ru.siz))
    {
        _vec = 0;
        clear(u);
        u = build(1, _vec, dep);
    }
}

修改

void pushup(int u)
{
    for (int k = 0; k < K; ++k)
    {
        tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
        tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
    }
    tu.sum = lu.sum + ru.sum + tu.p.val;
    tu.siz = lu.siz + ru.siz + 1;
}
void down(int u, int tag)
{
    if (u)
    {
        tu.p.val += tag;
        tu.tag += tag;
        tu.sum += tu.siz * tag;
    }
}
void pushdown(int u)
{
    if (tu.tag)
    {
        down(tu.ls, tu.tag);
        down(tu.rs, tu.tag);
        tu.tag = 0;
        pushup(u);
    }
}

判断

相离或包含的辅助函数。

bool in(int p[K], int low[K], int hig[K])
{
    for (int k = 0; k < K; ++k)
    {
        if (!(low[k] <= p[k] && p[k] <= hig[k]))
        {
            return 0;
        }
    }
    return 1;
}
inline bool out(int x, int y, int l, int r)
{
    return y < l || r < x;
}

插入

注意插入的 newnde() 只会返回未使用过的结点,而我们重载了 nde() 默认构造函数,所以可以直接 pushup()

void insert(int &u, const Pnt &p, int dep)
{
    if (!u)
    {
        u = newnde();
        tu.p = p;
        pushup(u);
        return;
    }
    pushdown(u);
    if (p.p[dep] < tu.p.p[dep])
    {
        insert(tu.ls, p, (dep + 1) % K);
    }
    else
    {
        insert(tu.rs, p, (dep + 1) % K);
    }
    pushup(u);
    check(u, dep);
}

加和查

避免传参,减少常数。

int _low[K], _hig[K], _val;
int query(int low[K], int hig[K])
{
    if (tot == 0)
    {
        return 0;
    }
    copy(low, low + K, _low), copy(hig, hig + K, _hig);
    return _query(rt);
}
void add(int low[K], int hig[K], int val)
{
    if (tot == 0)
    {
        return;
    }
    copy(low, low + K, _low), copy(hig, hig + K, _hig);
    _val = val;
    _add(rt);
}
int _query(int u)
{
    if (!u)
    {
        return 0;
    }
    pushdown(u);
    if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
    {
        return tu.sum;
    }
    for (int k = 0; k < K; ++k)
    {
        if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
        {
            return 0;
        }
    }
    int res = 0;
    if (in(tu.p.p, _low, _hig))
    {
        res += tu.p.val;
    }
    return res + _query(tu.ls) + _query(tu.rs);
}
void _add(int u)
{
    if (!u)
    {
        return;
    }
    pushdown(u);
    if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
    {
        down(u, _val);
        return;
    }
    for (int k = 0; k < K; ++k)
    {
        if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
        {
            return;
        }
    }
    if (in(tu.p.p, _low, _hig))
    {
        tu.p.val += _val;
    }
    _add(tu.ls), _add(tu.rs);
    pushup(u);
}

以下是额外补的。

建树

对于这道题不需要提前对若干结点建树,若需要的话,可以类似这样(另一份代码里拷的,自行类比)。

cin >> n;
for (int i = 1; i <= n; ++i)
{
    cin >> t.vec[i].x >> t.vec[i].y >> t.vec[i].z;
    t.vec[i].cnt = 1;
}
t.rt = t.build(1, n, 0);

删点

曾经被坑过,乱删复杂度就炸了,或是 RE 死活调不出来。

可以插入负的点权,如数点就加入 cnt=-1 的点。

Code

拼起来,同时复制一份改成三维,就有了 13K,因为封装了又套了层 namespace 避免重名,光是 tab 就有 6K

:::info[屎山]

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int ll

const int INF = 1e18;
const int MOD = 998244353;
using it2 = array<int, 2>;

#define tu t[u]
#define lu t[tu.ls]
#define ru t[tu.rs]
#define ALPHA 0.7
namespace KDT2
{
    constexpr int K = 2;
    constexpr int N = 1.5e5 + 3;
    struct Pnt
    {
        int p[K], val;
    };
    struct KDT
    {
        int tot, rt;
        struct nde
        {
            int ls, rs, low[K], hig[K], siz, sum, tag;
            Pnt p;
            nde()
            {
                ls = rs = siz = sum = tag = 0;
                fill(low, low + K, INF), fill(hig, hig + K, -INF);
            }
        } t[N];
        int poo[N];
        int _vec;
        Pnt vec[N];
        void pushup(int u)
        {
            for (int k = 0; k < K; ++k)
            {
                tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
                tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
            }
            tu.sum = lu.sum + ru.sum + tu.p.val;
            tu.siz = lu.siz + ru.siz + 1;
        }
        void down(int u, int tag)
        {
            if (u)
            {
                tu.p.val += tag;
                tu.tag += tag;
                tu.sum += tu.siz * tag;
            }
        }
        void pushdown(int u)
        {
            if (tu.tag)
            {
                down(tu.ls, tu.tag);
                down(tu.rs, tu.tag);
                tu.tag = 0;
                pushup(u);
            }
        }
        int newnde()
        {
            if (*poo)
            {
                return poo[(*poo)--];
            }
            return ++tot;
        }
        void clear(int u)
        {
            if (!u)
            {
                return;
            }
            pushdown(u);
            clear(tu.ls);
            vec[++_vec] = tu.p;
            clear(tu.rs);
            tu = nde();
            poo[++*poo] = u;
        }
        bool in(int p[K], int low[K], int hig[K])
        {
            for (int k = 0; k < K; ++k)
            {
                if (!(low[k] <= p[k] && p[k] <= hig[k]))
                {
                    return 0;
                }
            }
            return 1;
        }
        inline bool out(int x, int y, int l, int r)
        {
            return y < l || r < x;
        }
        int build(int l, int r, int dep)
        {
            if (l > r)
            {
                return 0;
            }
            int u = newnde(), mid = (l + r) >> 1;
            nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
                        { return x.p[dep] < y.p[dep]; });
            tu.p = vec[mid];
            tu.ls = build(l, mid - 1, (dep + 1) % K);
            tu.rs = build(mid + 1, r, (dep + 1) % K);
            pushup(u);
            return u;
        }
        void check(int &u, int dep)
        {
            if (tu.siz * ALPHA < max(lu.siz, ru.siz))
            {
                _vec = 0;
                clear(u);
                u = build(1, _vec, dep);
            }
        }
        void insert(int &u, const Pnt &p, int dep)
        {
            if (!u)
            {
                u = newnde();
                tu.p = p;
                pushup(u);
                return;
            }
            pushdown(u);
            if (p.p[dep] < tu.p.p[dep])
            {
                insert(tu.ls, p, (dep + 1) % K);
            }
            else
            {
                insert(tu.rs, p, (dep + 1) % K);
            }
            pushup(u);
            check(u, dep);
        }
        int _low[K], _hig[K], _val;
        int query(int low[K], int hig[K])
        {
            if (tot == 0)
            {
                return 0;
            }
            copy(low, low + K, _low), copy(hig, hig + K, _hig);
            return _query(rt);
        }
        void add(int low[K], int hig[K], int val)
        {
            if (tot == 0)
            {
                return;
            }
            copy(low, low + K, _low), copy(hig, hig + K, _hig);
            _val = val;
            _add(rt);
        }
        int _query(int u)
        {
            if (!u)
            {
                return 0;
            }
            pushdown(u);
            if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
            {
                return tu.sum;
            }
            for (int k = 0; k < K; ++k)
            {
                if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
                {
                    return 0;
                }
            }
            int res = 0;
            if (in(tu.p.p, _low, _hig))
            {
                res += tu.p.val;
            }
            return res + _query(tu.ls) + _query(tu.rs);
        }
        void _add(int u)
        {
            if (!u)
            {
                return;
            }
            pushdown(u);
            if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
            {
                down(u, _val);
                return;
            }
            for (int k = 0; k < K; ++k)
            {
                if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
                {
                    return;
                }
            }
            if (in(tu.p.p, _low, _hig))
            {
                tu.p.val += _val;
            }
            _add(tu.ls), _add(tu.rs);
            pushup(u);
        }
    };
}
namespace KDT3
{
    constexpr int K = 3;
    constexpr int N = 1e5 + 3;
    struct Pnt
    {
        int p[K], val;
    };
    struct KDT
    {
        int tot, rt;
        struct nde
        {
            int ls, rs, low[K], hig[K], siz, sum, tag;
            Pnt p;
            nde()
            {
                ls = rs = siz = sum = tag = 0;
                fill(low, low + K, INF), fill(hig, hig + K, -INF);
            }
        } t[N];
        int poo[N];
        int _vec;
        Pnt vec[N];
        void pushup(int u)
        {
            for (int k = 0; k < K; ++k)
            {
                tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
                tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
            }
            tu.sum = lu.sum + ru.sum + tu.p.val;
            tu.siz = lu.siz + ru.siz + 1;
        }
        void down(int u, int tag)
        {
            if (u)
            {
                tu.p.val += tag;
                tu.tag += tag;
                tu.sum += tu.siz * tag;
            }
        }
        void pushdown(int u)
        {
            if (tu.tag)
            {
                down(tu.ls, tu.tag);
                down(tu.rs, tu.tag);
                tu.tag = 0;
                pushup(u);
            }
        }
        int newnde()
        {
            if (*poo)
            {
                return poo[(*poo)--];
            }
            return ++tot;
        }
        void clear(int u)
        {
            if (!u)
            {
                return;
            }
            pushdown(u);
            clear(tu.ls);
            vec[++_vec] = tu.p;
            clear(tu.rs);
            tu = nde();
            poo[++*poo] = u;
        }
        bool in(int p[K], int low[K], int hig[K])
        {
            for (int k = 0; k < K; ++k)
            {
                if (!(low[k] <= p[k] && p[k] <= hig[k]))
                {
                    return 0;
                }
            }
            return 1;
        }
        inline bool out(int x, int y, int l, int r)
        {
            return y < l || r < x;
        }
        int build(int l, int r, int dep)
        {
            if (l > r)
            {
                return 0;
            }
            int u = newnde(), mid = (l + r) >> 1;
            nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
                        { return x.p[dep] < y.p[dep]; });
            tu.p = vec[mid];
            tu.ls = build(l, mid - 1, (dep + 1) % K);
            tu.rs = build(mid + 1, r, (dep + 1) % K);
            pushup(u);
            return u;
        }
        void check(int &u, int dep)
        {
            if (tu.siz * ALPHA < max(lu.siz, ru.siz))
            {
                _vec = 0;
                clear(u);
                u = build(1, _vec, dep);
            }
        }
        void insert(int &u, const Pnt &p, int dep)
        {
            if (!u)
            {
                u = newnde();
                tu.p = p;
                pushup(u);
                return;
            }
            pushdown(u);
            if (p.p[dep] < tu.p.p[dep])
            {
                insert(tu.ls, p, (dep + 1) % K);
            }
            else
            {
                insert(tu.rs, p, (dep + 1) % K);
            }
            pushup(u);
            check(u, dep);
        }
        int _low[K], _hig[K], _val;
        int query(int low[K], int hig[K])
        {
            if (tot == 0)
            {
                return 0;
            }
            copy(low, low + K, _low), copy(hig, hig + K, _hig);
            return _query(rt);
        }
        void add(int low[K], int hig[K], int val)
        {
            if (tot == 0)
            {
                return;
            }
            copy(low, low + K, _low), copy(hig, hig + K, _hig);
            _val = val;
            _add(rt);
        }
        int _query(int u)
        {
            if (!u)
            {
                return 0;
            }
            pushdown(u);
            if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
            {
                return tu.sum;
            }
            for (int k = 0; k < K; ++k)
            {
                if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
                {
                    return 0;
                }
            }
            int res = 0;
            if (in(tu.p.p, _low, _hig))
            {
                res += tu.p.val;
            }
            return res + _query(tu.ls) + _query(tu.rs);
        }
        void _add(int u)
        {
            if (!u)
            {
                return;
            }
            pushdown(u);
            if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
            {
                down(u, _val);
                return;
            }
            for (int k = 0; k < K; ++k)
            {
                if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
                {
                    return;
                }
            }
            if (in(tu.p.p, _low, _hig))
            {
                tu.p.val += _val;
            }
            _add(tu.ls), _add(tu.rs);
            pushup(u);
        }
    };
}

KDT2::KDT kdt2;
KDT3::KDT kdt3;

signed main()
{
    cin.tie(0)->sync_with_stdio(false), cout.setf(ios::fixed), cout.precision(10);

    int k, m, lst = 0;
    assert(cin >> k >> m);
    if (k == 2)
    {
        const int K = 2;
        int op, low[K], hig[K], val;
        KDT2::Pnt tmp;
        while (m--)
        {
            cin >> op;
            if (op == 1)
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> tmp.p[k], tmp.p[k] ^= lst;
                }
                cin >> tmp.val, tmp.val ^= lst;
                kdt2.insert(kdt2.rt, tmp, 0);
            }
            else if (op == 2)
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> low[k], low[k] ^= lst;
                }
                for (int k = 0; k < K; ++k)
                {
                    cin >> hig[k], hig[k] ^= lst;
                }
                cin >> val, val ^= lst;
                kdt2.add(low, hig, val);
            }
            else
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> low[k], low[k] ^= lst;
                }
                for (int k = 0; k < K; ++k)
                {
                    cin >> hig[k], hig[k] ^= lst;
                }
                cout << (lst = kdt2.query(low, hig)) << '\n';
            }
        }
    }
    else
    {
        const int K = 3;
        int op, low[K], hig[K], val;
        KDT3::Pnt tmp;
        while (m--)
        {
            cin >> op;
            if (op == 1)
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> tmp.p[k], tmp.p[k] ^= lst;
                }
                cin >> tmp.val, tmp.val ^= lst;
                kdt3.insert(kdt3.rt, tmp, 0);
            }
            else if (op == 2)
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> low[k], low[k] ^= lst;
                }
                for (int k = 0; k < K; ++k)
                {
                    cin >> hig[k], hig[k] ^= lst;
                }
                cin >> val, val ^= lst;
                kdt3.add(low, hig, val);
            }
            else
            {
                for (int k = 0; k < K; ++k)
                {
                    cin >> low[k], low[k] ^= lst;
                }
                for (int k = 0; k < K; ++k)
                {
                    cin >> hig[k], hig[k] ^= lst;
                }
                cout << (lst = kdt3.query(low, hig)) << '\n';
            }
        }
    }

    return 0;
}

:::