P14312 【模板】K-D Tree の题解
kdt 略解
简介:本文主要介绍 kdt 的结构体封装和替罪羊维护写法。
其实我都不会替罪羊,只会替罪羊维护的 kdt,同时更新了复杂度证明。
题意
本题需要支持的操作:
- 插点;
- 矩形(立方体)加;
- 矩形(立方体)求和。
然后卡空间,时限很松。
这不就是 kdt 板题吗。
算法介绍
应用
可以高效支持在线高维矩形查询问题。
也可以通过剪枝解决一些平面点对问题。
原理
以下只讨论二维情况,三维请自行类比。
kdt 具有二叉搜索树的形态,二叉搜索树人话就是中序遍历是有序的带权二叉树。
既然中序遍历是有序的可以搜索,我们可以通过每次选中位数做根,建出一棵树,例如:
before: (0,4) (2,0) (1,3) (2,4) (1,4) (2,2) (0,0)
sorted: (0,0) (0,4) (1,3) (1,4) (2,0) (2,2) (2,4)
(1,4)
/ \
(0,4) (2,2)
/ \ / \
(0,0) (1,3) (2,0) (2,4)
如果直接这样建,建完观察 (0,4) 所在子树,不像 (2,2) 的子树某一维有序,这样不方便查询矩形,所以每次我们选择某一维的中位数,切割成两个子矩形,方便查找以及查找时的剪枝。
虽然这样操作后就不是二叉搜索树了,但是具有二叉搜索树的形态(大概就是这个意思),所以我们的矩形查询能被切割成两个子矩形,同时为了保证复杂度,每一维轮流选中位数。
before:
(1,4)
/ \
(0,4) (2,2)
/ \ / \
(0,0) (1,3) (2,0) (2,4)
now:
(1,4)
/ \ sorted by x
(1,3) (2,2)
/ \ / \ sorted by y
(0,0) (0,4) (2,0) (2,4)
或者看图。
这样我们要查找矩形 [(0,2),(2,4)] 就可以这样递归搜索。
included-(1,3) ---+ +--- included-(2,2)
(1,4) | [(0,2),(2,4)] |
/ \ divided by x=1 -+--/ \-+- included-(1,4)
(1,3) (2,2) [(0,2),(1,4)] [(1,2),(2,4)]
/ \ / \ / by y=3 \ / by y=2 \
(0,0) (0,4) (2,0) (2,4) (none) included-(0,4) (skip) included-(2,4)
亦或看图。你看得出来在打架,我想尝试让你懂。
复杂度证明
我们查询一个矩形有如下三个流程:
- 如果查询与当前结点的区域无交集,直接跳出;
- 如果查询将当前结点的区域包含,直接跳出并上传答案;
- 有交集但不包含,继续递归求解。
明显的,复杂度来源于第三类查询。
因为上文说到按照轮流按照每一维进行划分,这样平面就会划分为若干个矩形。
假设我们查询一条竖线左侧的结点(结点代表一个矩形),那么按照竖线(也就是
引用一张来自 Wallace 的图。
可以清楚看到,每两层可以减一次枝,每个结点又有两个儿子,也就是说每隔一层点数翻倍。
因此复杂度是:
其中
代码实现
建树
建树时需要求中位数,并且要把小于中位数的放一边,剩下的放另一边。
你可能会说直接提前 sort 不行吗,确实不可以直接 sort,因为不同深度会按不同维度排序,所以你可以手打一个带层数的快排,以支持不同深度会按不同维度排序。
更好的懒人方法可以在建树的时候直接用 nth_element(),自动把中位数排到正确位置,同时可以自定义比较函数。
复杂度:上述的建树过程本质就是带层数的快排,显然
查询
为了支持高维矩形查询,我们需要记录每一个矩形中每一维度上的坐标的最大值和最小值,递归查询时发现与这个矩形无交(上面的 (skip))就跳过,否则递归到分割后的矩形。
复杂度见复杂度证明。
插入
如何实现插入,直接看作一个矩形搜到对应空结点新建即可。
但是如果往一个地方插入过多的结点就炸了,所以需要重构。
复杂度见复杂度证明。
重构
具体的,如果一个结点的较大子树大小超过预先设定的比例,就炸掉重构。
重构很简单,暴力遍历把点拎出来再重新建树,期望复杂度类似替罪羊,但是会对查询复杂度有所影响。
小优化:改为根号重构或二进制分组,简单修改即可,这样就能保证树高了。
复杂度:单次重构
区间加
还是搜到矩阵,打个 tag 就行了,以后每次搜到一个结点就 pushdown。
实现细节
K 是维数。
使用宏定义以减少实现难度。
#define tu t[u]
#define lu t[tu.ls]
#define ru t[tu.rs]
#define ALPHA 0.7
点
点是单独维护的,单开个结构体,和结点不同,结点要维护子树信息。
struct Pnt
{
int p[K], val;
};
结点
注意构造时清空,方便实现。
最好写成默认构造。
struct nde
{
int ls, rs, low[K], hig[K], siz, sum, tag;
Pnt p;
nde()
{
ls = rs = siz = sum = tag = 0;
fill(low, low + K, INF), fill(hig, hig + K, -INF);
}
} t[N];
默认清空,方便以下调用。
维护树
包括了清空、获取新结点、建树和重构。
注意每次清空后要重设上下界,而这一步通过 nde() 的构造函数实现。
int poo[N];
int _vec;
Pnt vec[N];
int newnde()
{
if (*poo)
{
return poo[(*poo)--];
}
return ++tot;
}
void clear(int u)
{
if (!u)
{
return;
}
pushdown(u);
clear(tu.ls);
vec[++_vec] = tu.p;
clear(tu.rs);
tu = nde();
poo[++*poo] = u;
}
int build(int l, int r, int dep)
{
if (l > r)
{
return 0;
}
int u = newnde(), mid = (l + r) >> 1;
nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
{ return x.p[dep] < y.p[dep]; });
tu.p = vec[mid];
tu.ls = build(l, mid - 1, (dep + 1) % K);
tu.rs = build(mid + 1, r, (dep + 1) % K);
pushup(u);
return u;
}
void check(int &u, int dep)
{
if (tu.siz * ALPHA < max(lu.siz, ru.siz))
{
_vec = 0;
clear(u);
u = build(1, _vec, dep);
}
}
修改
void pushup(int u)
{
for (int k = 0; k < K; ++k)
{
tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
}
tu.sum = lu.sum + ru.sum + tu.p.val;
tu.siz = lu.siz + ru.siz + 1;
}
void down(int u, int tag)
{
if (u)
{
tu.p.val += tag;
tu.tag += tag;
tu.sum += tu.siz * tag;
}
}
void pushdown(int u)
{
if (tu.tag)
{
down(tu.ls, tu.tag);
down(tu.rs, tu.tag);
tu.tag = 0;
pushup(u);
}
}
判断
相离或包含的辅助函数。
bool in(int p[K], int low[K], int hig[K])
{
for (int k = 0; k < K; ++k)
{
if (!(low[k] <= p[k] && p[k] <= hig[k]))
{
return 0;
}
}
return 1;
}
inline bool out(int x, int y, int l, int r)
{
return y < l || r < x;
}
插入
注意插入的 newnde() 只会返回未使用过的结点,而我们重载了 nde() 默认构造函数,所以可以直接 pushup()。
void insert(int &u, const Pnt &p, int dep)
{
if (!u)
{
u = newnde();
tu.p = p;
pushup(u);
return;
}
pushdown(u);
if (p.p[dep] < tu.p.p[dep])
{
insert(tu.ls, p, (dep + 1) % K);
}
else
{
insert(tu.rs, p, (dep + 1) % K);
}
pushup(u);
check(u, dep);
}
加和查
避免传参,减少常数。
int _low[K], _hig[K], _val;
int query(int low[K], int hig[K])
{
if (tot == 0)
{
return 0;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
return _query(rt);
}
void add(int low[K], int hig[K], int val)
{
if (tot == 0)
{
return;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
_val = val;
_add(rt);
}
int _query(int u)
{
if (!u)
{
return 0;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
return tu.sum;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return 0;
}
}
int res = 0;
if (in(tu.p.p, _low, _hig))
{
res += tu.p.val;
}
return res + _query(tu.ls) + _query(tu.rs);
}
void _add(int u)
{
if (!u)
{
return;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
down(u, _val);
return;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return;
}
}
if (in(tu.p.p, _low, _hig))
{
tu.p.val += _val;
}
_add(tu.ls), _add(tu.rs);
pushup(u);
}
以下是额外补的。
建树
对于这道题不需要提前对若干结点建树,若需要的话,可以类似这样(另一份代码里拷的,自行类比)。
cin >> n;
for (int i = 1; i <= n; ++i)
{
cin >> t.vec[i].x >> t.vec[i].y >> t.vec[i].z;
t.vec[i].cnt = 1;
}
t.rt = t.build(1, n, 0);
删点
曾经被坑过,乱删复杂度就炸了,或是 RE 死活调不出来。
可以插入负的点权,如数点就加入 cnt=-1 的点。
Code
拼起来,同时复制一份改成三维,就有了 namespace 避免重名,光是 tab 就有
:::info[屎山]
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int ll
const int INF = 1e18;
const int MOD = 998244353;
using it2 = array<int, 2>;
#define tu t[u]
#define lu t[tu.ls]
#define ru t[tu.rs]
#define ALPHA 0.7
namespace KDT2
{
constexpr int K = 2;
constexpr int N = 1.5e5 + 3;
struct Pnt
{
int p[K], val;
};
struct KDT
{
int tot, rt;
struct nde
{
int ls, rs, low[K], hig[K], siz, sum, tag;
Pnt p;
nde()
{
ls = rs = siz = sum = tag = 0;
fill(low, low + K, INF), fill(hig, hig + K, -INF);
}
} t[N];
int poo[N];
int _vec;
Pnt vec[N];
void pushup(int u)
{
for (int k = 0; k < K; ++k)
{
tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
}
tu.sum = lu.sum + ru.sum + tu.p.val;
tu.siz = lu.siz + ru.siz + 1;
}
void down(int u, int tag)
{
if (u)
{
tu.p.val += tag;
tu.tag += tag;
tu.sum += tu.siz * tag;
}
}
void pushdown(int u)
{
if (tu.tag)
{
down(tu.ls, tu.tag);
down(tu.rs, tu.tag);
tu.tag = 0;
pushup(u);
}
}
int newnde()
{
if (*poo)
{
return poo[(*poo)--];
}
return ++tot;
}
void clear(int u)
{
if (!u)
{
return;
}
pushdown(u);
clear(tu.ls);
vec[++_vec] = tu.p;
clear(tu.rs);
tu = nde();
poo[++*poo] = u;
}
bool in(int p[K], int low[K], int hig[K])
{
for (int k = 0; k < K; ++k)
{
if (!(low[k] <= p[k] && p[k] <= hig[k]))
{
return 0;
}
}
return 1;
}
inline bool out(int x, int y, int l, int r)
{
return y < l || r < x;
}
int build(int l, int r, int dep)
{
if (l > r)
{
return 0;
}
int u = newnde(), mid = (l + r) >> 1;
nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
{ return x.p[dep] < y.p[dep]; });
tu.p = vec[mid];
tu.ls = build(l, mid - 1, (dep + 1) % K);
tu.rs = build(mid + 1, r, (dep + 1) % K);
pushup(u);
return u;
}
void check(int &u, int dep)
{
if (tu.siz * ALPHA < max(lu.siz, ru.siz))
{
_vec = 0;
clear(u);
u = build(1, _vec, dep);
}
}
void insert(int &u, const Pnt &p, int dep)
{
if (!u)
{
u = newnde();
tu.p = p;
pushup(u);
return;
}
pushdown(u);
if (p.p[dep] < tu.p.p[dep])
{
insert(tu.ls, p, (dep + 1) % K);
}
else
{
insert(tu.rs, p, (dep + 1) % K);
}
pushup(u);
check(u, dep);
}
int _low[K], _hig[K], _val;
int query(int low[K], int hig[K])
{
if (tot == 0)
{
return 0;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
return _query(rt);
}
void add(int low[K], int hig[K], int val)
{
if (tot == 0)
{
return;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
_val = val;
_add(rt);
}
int _query(int u)
{
if (!u)
{
return 0;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
return tu.sum;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return 0;
}
}
int res = 0;
if (in(tu.p.p, _low, _hig))
{
res += tu.p.val;
}
return res + _query(tu.ls) + _query(tu.rs);
}
void _add(int u)
{
if (!u)
{
return;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
down(u, _val);
return;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return;
}
}
if (in(tu.p.p, _low, _hig))
{
tu.p.val += _val;
}
_add(tu.ls), _add(tu.rs);
pushup(u);
}
};
}
namespace KDT3
{
constexpr int K = 3;
constexpr int N = 1e5 + 3;
struct Pnt
{
int p[K], val;
};
struct KDT
{
int tot, rt;
struct nde
{
int ls, rs, low[K], hig[K], siz, sum, tag;
Pnt p;
nde()
{
ls = rs = siz = sum = tag = 0;
fill(low, low + K, INF), fill(hig, hig + K, -INF);
}
} t[N];
int poo[N];
int _vec;
Pnt vec[N];
void pushup(int u)
{
for (int k = 0; k < K; ++k)
{
tu.low[k] = min({lu.low[k], ru.low[k], tu.p.p[k]});
tu.hig[k] = max({lu.hig[k], ru.hig[k], tu.p.p[k]});
}
tu.sum = lu.sum + ru.sum + tu.p.val;
tu.siz = lu.siz + ru.siz + 1;
}
void down(int u, int tag)
{
if (u)
{
tu.p.val += tag;
tu.tag += tag;
tu.sum += tu.siz * tag;
}
}
void pushdown(int u)
{
if (tu.tag)
{
down(tu.ls, tu.tag);
down(tu.rs, tu.tag);
tu.tag = 0;
pushup(u);
}
}
int newnde()
{
if (*poo)
{
return poo[(*poo)--];
}
return ++tot;
}
void clear(int u)
{
if (!u)
{
return;
}
pushdown(u);
clear(tu.ls);
vec[++_vec] = tu.p;
clear(tu.rs);
tu = nde();
poo[++*poo] = u;
}
bool in(int p[K], int low[K], int hig[K])
{
for (int k = 0; k < K; ++k)
{
if (!(low[k] <= p[k] && p[k] <= hig[k]))
{
return 0;
}
}
return 1;
}
inline bool out(int x, int y, int l, int r)
{
return y < l || r < x;
}
int build(int l, int r, int dep)
{
if (l > r)
{
return 0;
}
int u = newnde(), mid = (l + r) >> 1;
nth_element(vec + l, vec + mid, vec + r + 1, [&](const Pnt &x, const Pnt &y)
{ return x.p[dep] < y.p[dep]; });
tu.p = vec[mid];
tu.ls = build(l, mid - 1, (dep + 1) % K);
tu.rs = build(mid + 1, r, (dep + 1) % K);
pushup(u);
return u;
}
void check(int &u, int dep)
{
if (tu.siz * ALPHA < max(lu.siz, ru.siz))
{
_vec = 0;
clear(u);
u = build(1, _vec, dep);
}
}
void insert(int &u, const Pnt &p, int dep)
{
if (!u)
{
u = newnde();
tu.p = p;
pushup(u);
return;
}
pushdown(u);
if (p.p[dep] < tu.p.p[dep])
{
insert(tu.ls, p, (dep + 1) % K);
}
else
{
insert(tu.rs, p, (dep + 1) % K);
}
pushup(u);
check(u, dep);
}
int _low[K], _hig[K], _val;
int query(int low[K], int hig[K])
{
if (tot == 0)
{
return 0;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
return _query(rt);
}
void add(int low[K], int hig[K], int val)
{
if (tot == 0)
{
return;
}
copy(low, low + K, _low), copy(hig, hig + K, _hig);
_val = val;
_add(rt);
}
int _query(int u)
{
if (!u)
{
return 0;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
return tu.sum;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return 0;
}
}
int res = 0;
if (in(tu.p.p, _low, _hig))
{
res += tu.p.val;
}
return res + _query(tu.ls) + _query(tu.rs);
}
void _add(int u)
{
if (!u)
{
return;
}
pushdown(u);
if (in(tu.low, _low, _hig) && in(tu.hig, _low, _hig))
{
down(u, _val);
return;
}
for (int k = 0; k < K; ++k)
{
if (out(tu.low[k], tu.hig[k], _low[k], _hig[k]))
{
return;
}
}
if (in(tu.p.p, _low, _hig))
{
tu.p.val += _val;
}
_add(tu.ls), _add(tu.rs);
pushup(u);
}
};
}
KDT2::KDT kdt2;
KDT3::KDT kdt3;
signed main()
{
cin.tie(0)->sync_with_stdio(false), cout.setf(ios::fixed), cout.precision(10);
int k, m, lst = 0;
assert(cin >> k >> m);
if (k == 2)
{
const int K = 2;
int op, low[K], hig[K], val;
KDT2::Pnt tmp;
while (m--)
{
cin >> op;
if (op == 1)
{
for (int k = 0; k < K; ++k)
{
cin >> tmp.p[k], tmp.p[k] ^= lst;
}
cin >> tmp.val, tmp.val ^= lst;
kdt2.insert(kdt2.rt, tmp, 0);
}
else if (op == 2)
{
for (int k = 0; k < K; ++k)
{
cin >> low[k], low[k] ^= lst;
}
for (int k = 0; k < K; ++k)
{
cin >> hig[k], hig[k] ^= lst;
}
cin >> val, val ^= lst;
kdt2.add(low, hig, val);
}
else
{
for (int k = 0; k < K; ++k)
{
cin >> low[k], low[k] ^= lst;
}
for (int k = 0; k < K; ++k)
{
cin >> hig[k], hig[k] ^= lst;
}
cout << (lst = kdt2.query(low, hig)) << '\n';
}
}
}
else
{
const int K = 3;
int op, low[K], hig[K], val;
KDT3::Pnt tmp;
while (m--)
{
cin >> op;
if (op == 1)
{
for (int k = 0; k < K; ++k)
{
cin >> tmp.p[k], tmp.p[k] ^= lst;
}
cin >> tmp.val, tmp.val ^= lst;
kdt3.insert(kdt3.rt, tmp, 0);
}
else if (op == 2)
{
for (int k = 0; k < K; ++k)
{
cin >> low[k], low[k] ^= lst;
}
for (int k = 0; k < K; ++k)
{
cin >> hig[k], hig[k] ^= lst;
}
cin >> val, val ^= lst;
kdt3.add(low, hig, val);
}
else
{
for (int k = 0; k < K; ++k)
{
cin >> low[k], low[k] ^= lst;
}
for (int k = 0; k < K; ++k)
{
cin >> hig[k], hig[k] ^= lst;
}
cout << (lst = kdt3.query(low, hig)) << '\n';
}
}
}
return 0;
}
:::