题解:P14312 【模板】K-D Tree

· · 题解

K-D Tree 是一种很巧妙的数据结构,可以高效维护一个 k 维空间。

限于个人水平,文中可能存在不足之处,望不吝批评指正。

K-D Tree 简介

K-D Tree 是一种维护 k 维空间信息的二叉树,其每个节点代表空间的一部分,每个节点的左右子树都是两个不相交的 k 维空间。

相比于 cdq 分治,K-D Tree 可以处理强制在线的问题;相比于树套树,K-D Tree 的内存消耗又很小。

在 OI 中,一般有 k = 2k = 3,极少数情况下会用到 k = 4,很少见到 k 很大的题。因此,本文中所有复杂度分析都认为 k 是常数。

建树(Build)

下面我们以 k = 2 为例,讲述 K-D Tree 的建树过程。

第一步,我们需要选择一个维度,然后把空间分割。分割时要找到这个维度的中位数为分界线,把空间内的元素均匀地分成两部分。以下图为例:

图中,我们选择了横坐标的维度,然后找到了这个维度的中位数(红色的点),并把空间分成了左右两个子树。当 k = 3 时同理,空间会被划分为两个长方体。

第二步,我们需要继续划分空间,但应当换一个维度。具体地,以 k = 2 为例,要横着切一下,再竖着切一下,交替进行。

下面解释一下为什么要交替进行。不妨假设我每次都是竖着切的,不更换维度。那么,此时我构造一组横坐标相同的点(类似于一条链),就可以使 K-D Tree 退化成链了,复杂度自然爆炸。

因此,我们需要对每个维度轮流处理。于是第二次划分后 K-D Tree 长成这样:

然后递归下去即可。这种建树的方式保证了树高为 \mathcal{O}(\log n) 级别,使得总时间复杂度正确。

查询(Query)

在 K-D Tree 上查询是简单的。如果你理解线段树,则能够更快理解这一部分。

在线段树上查询时,我们从根节点开始往下递归,每次检查节点对应的区间于查询区间的关系。如果发现两个区间不交,则没有贡献;如果发现查询的区间包含这个节点对应的区间,则可以直接加上节点的权值和;否则,需要递归下去处理。

对应到 K-D Tree 上是一样的。如果没有交集,就没有贡献;如果完全包含,可以直接加上权值;否则,才需要继续递归下去求解。下面简单说明 k = 2 时其时间复杂度为 \mathcal{O}(\sqrt{n})

以上图(把平面分为四个部分的)为例,容易发现,无论我怎么画水平或竖直的线,都最多穿过两个矩形(不考虑在边界上的情况,因为总能够通过微小的扰动解决)。因此,这种情况下一次最多经过两个被分割的矩形。

因此有 T(n) = 2 + 2T(\frac{n}{4})。显然递归层数为 \log_4n,根据等比数列求和公式,可以得到其时间复杂度为 \mathcal{O}(2^{\log_4 n}) = \mathcal{O}(n^{\log_4 2}) = \mathcal{O}(\sqrt{n})

同理可以扩展到 k 维的情况,时间复杂度为 \mathcal{O}(n^{\frac{k - 1}{k}})

修改(Modify)

修改与查询类似,也是递归处理。

类似于线段树,我们递归下去,遇到严格包含的区间就打懒标记然后结束递归,否则递归处理。

不过要注意,类似线段树,递归处理前要下放懒标记,即调用 push_down 函数。

时间复杂度的证明与查询部分一致,这里不再赘述。

插入(Insert)

下面考虑在 K-D Tree 中插入一个点,这是线段树中没有的操作。

首先暴力插入肯定是错的,因为直接插进去的话二叉树就不平衡了。

但是,K-D Tree 是不可旋转的,因此也不能像平衡树那样处理。

因此,添加节点导致的不平衡只能通过重构解决。

市面上常见的做法有替罪羊树、根号重构、二进制分组等等。当然,也存在一些比较难的论文做法,例如 qbf 老师的集训队论文中提到的空间优化型做法。

本文主要讲解二进制分组做法,因为确实速度较快,且容易在考场上实现。

这个做法顾名思义,就是按照二进制,建 \log n 棵 K-D Tree,大小为 2 的非负整数次幂。每次插入一个新的节点,我们会发现总点数 n 增加 1,不妨把这个过程放到二进制上去思考。每次插入一个数,最低位会增加 1,此时插到第一棵大小为 2^0 = 1 的 K-D Tree 中。接下来,如果发现目前 K-D Tree 的大小超过了规定的大小(例如第二次插入,第一棵树就有 2 > 1 个节点),就要进行进位操作。对应到树上,就是两棵 K-D Tree 发生了合并,然后对应低位的 K-D Tree 清空。当然,有可能完成进位后下一位又爆了,此时需要继续做进位。每次查询和修改都在这 \mathcal{O}(\log n) 棵 K-D Tree 上全部做一遍即可。

最后我们只需要解决 K-D Tree 合并的细节了。每次合并的时候,我们暴力把两棵 K-D Tree 上的节点全拿下来拍到 k 维空间里,重新建树即可。

每次插入一个点相当于二进制加法。由于建树操作本身带 \log,因此插入部分总时间复杂度为 \mathcal{O}(n \log^2 n)

对于查询或修改部分,时间复杂度并没有退化到 \mathcal{O}(\sqrt{n}\log n),因为:

\mathcal{O}(\sum\limits_{i = 0}^{\log n} \sqrt{2^i}) = \mathcal{O}(\sum\limits_{i = 0}^{\log n} \sqrt{2}^i) = \mathcal{O}(\frac{\sqrt{2}^{\log n} - 1}{\sqrt{2} - 1}) = \mathcal{O}(\sqrt{2}^{\log_2n}) = \mathcal{O}(n^{\log_2\sqrt{2}}) = \mathcal{O}(\sqrt{n})

同理可以推广到任意 k,时间复杂度仍然是 \mathcal{O}(n^{\frac{k - 1}{k}})

参考实现

下面给出模板题的参考代码(写的很烂,仅供参考)。

::::info[Code]{open}

提交记录:https://www.luogu.com.cn/record/264963582。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline char gc(){return getchar();}
inline void pc(char ch){putchar(ch);}
inline ll rd(){
    ll x = 0, f = 1;
    char ch = gc();
    while(!isdigit(ch)){
        if(ch == '-') f = -1;
        ch = gc();
    }
    while(isdigit(ch)){
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = gc();
    }
    return x * f;
}
inline void wr(ll x){
    if(x < 0) pc('-'), x = -x;
    if(x > 9) wr(x / 10);
    pc(x % 10 + '0');
    return;
}
ll INF = 1e18 + 114514;
int k; // 维度
struct Node{
    ll l[3] = {0, 0, 0}, r[3] = {INF, INF, INF}; // 管辖范围
    ll d[3]; // 当前点坐标
    int ls, rs; // 左右孩子
    ll sum, tag, val; // 权值和,懒标记,当前点权值
    int sz; // 大小
}t[150005]; int tot = 0; // 动态开点
inline void push_up(int x){
    if(!x) return;
    t[x].sz = 1, t[x].sum = t[x].val;
    for(int i = 0; i < k; i++) t[x].l[i] = t[x].r[i] = t[x].d[i];
    if(t[x].ls){
        t[x].sz += t[t[x].ls].sz, t[x].sum += t[t[x].ls].sum;
        for(int i = 0; i < k; i++){ // 合并
            t[x].l[i] = min(t[x].l[i], t[t[x].ls].l[i]);
            t[x].r[i] = max(t[x].r[i], t[t[x].ls].r[i]);
        }
    }
    if(t[x].rs){
        t[x].sz += t[t[x].rs].sz, t[x].sum += t[t[x].rs].sum;
        for(int i = 0; i < k; i++){
            t[x].l[i] = min(t[x].l[i], t[t[x].rs].l[i]);
            t[x].r[i] = max(t[x].r[i], t[t[x].rs].r[i]);
        }
    }
    return;
}
inline void push_down(int x){
    if(!x || !t[x].tag) return;
    t[t[x].ls].tag += t[x].tag;
    t[t[x].ls].sum += (ll)t[x].tag * t[t[x].ls].sz;
    t[t[x].rs].tag += t[x].tag;
    t[t[x].rs].sum += (ll)t[x].tag * t[t[x].rs].sz;
    t[x].val += t[x].tag, t[x].tag = 0; // 无需再修改 sum
    return;
}
int axis; inline bool cmp(int x, int y){return t[x].d[axis] < t[y].d[axis];}
inline int build(vector<int> &ids, int l, int r, int dep){
    if(l > r) return 0;
    int mid = (l + r) >> 1;
    axis = dep % k;
    nth_element(ids.begin() + l, ids.begin() + mid, ids.begin() + r + 1, cmp); //
    int id = ids[mid];
    t[id].ls = build(ids, l, mid - 1, dep + 1);
    t[id].rs = build(ids, mid + 1, r, dep + 1);
    push_up(id);
    return id;
}
inline int judge(int x, ll l[], ll r[]){
    if(!x) return 0;
    for(int i = 0; i < k; i++){
        if(t[x].r[i] < l[i] || t[x].l[i] > r[i]){ // 无交
            return 0;
        }
    }
    for(int i = 0; i < k; i++){
        if(t[x].l[i] < l[i] || t[x].r[i] > r[i]){ // 不完全包含
            return 1;
        }
    }
    return 2; // 完全包含
}
inline void modify(int p, ll l[], ll r[], ll val){
    int state = judge(p, l, r);
    if(!state) return; // 无交
    if(state == 2){
        t[p].tag += val;
        t[p].sum += (ll)val * t[p].sz;
        return;
    }
    push_down(p);
    bool in = true;
    for(int i = 0; i < k; i++){
        if(t[p].d[i] < l[i] || t[p].d[i] > r[i]){
            in = false;
            break;
        }
    }
    if(in) t[p].val += val, t[p].sum += val;
    modify(t[p].ls, l, r, val), modify(t[p].rs, l, r, val);
    push_up(p);
    return;
}
inline ll query(int p, ll l[], ll r[]){
    int state = judge(p, l, r);
    if(!state) return 0;
    if(state == 2) return t[p].sum;
    push_down(p);
    bool in = true;
    for(int i = 0; i < k; i++){
        if(t[p].d[i] < l[i] || t[p].d[i] > r[i]){
            in = false;
            break;
        }
    }
    return (ll)in * t[p].val + query(t[p].ls, l, r) + query(t[p].rs, l, r);
}
inline void push_out(int p, ll anc, vector<int> &out){
    if(!p) return;
    ll w = anc + t[p].tag;
    t[p].val += w, t[p].tag = 0;
    push_out(t[p].ls, w, out), push_out(t[p].rs, w, out);
    out.push_back(p);
    return;
}
struct KDForest{
    vector<int> ids;
    int root = 0;
    inline void clear(){ids.clear(), root = 0;}
}tree[20]; int mxlog, total;
inline void init(int k, int q){
    while((1 << mxlog) <= q) mxlog++;
    mxlog++;
    total = 0, ::k = k;
    return;
}
inline ll Query(ll l[], ll r[]){
    ll ans = 0;
    for(int bit = 0; bit < mxlog; bit++) ans += query(tree[bit].root, l, r);
    return ans;
}
inline void Modify(ll l[], ll r[], ll val){
    for(int bit = 0; bit < mxlog; bit++) modify(tree[bit].root, l, r, val);
    return;
}
inline void Insert(ll d[], ll val){
    int id = ++total;
    for(int i = 0; i < k; i++) t[id].l[i] = t[id].r[i] = t[id].d[i] = d[i];
    t[id].val = t[id].sum = val, t[id].sz = 1;
    vector<int> cur = {id};
    for(int bit = 0; bit < mxlog; bit++){
        if(tree[bit].ids.empty()){
            tree[bit].ids = move(cur);
            tree[bit].root = build(tree[bit].ids, 0, tree[bit].ids.size() - 1, 0);
            break;
        }else{
            vector<int> collected;
            push_out(tree[bit].root, 0, collected);
            vector<int> merged;
            for(auto p : tree[bit].ids) merged.push_back(p);
            for(auto p :           cur) merged.push_back(p);
            tree[bit].clear(), cur = move(merged);
        }
    }
    return;
}
signed main(){
    k = rd(); int q = rd(); init(k, q);
    ll lst = 0;
    while(q--){
        int op = rd();
        if(op == 1){
            ll x[k];
            for(int i = 0; i < k; i++) x[i] = rd() ^ lst;
            Insert(x, rd() ^ lst);
        }else if(op == 2){
            ll l[k], r[k];
            for(int i = 0; i < k; i++) l[i] = rd() ^ lst;
            for(int i = 0; i < k; i++) r[i] = rd() ^ lst;
            Modify(l, r, rd() ^ lst);
        }else{
            ll l[k], r[k];
            for(int i = 0; i < k; i++) l[i] = rd() ^ lst;
            for(int i = 0; i < k; i++) r[i] = rd() ^ lst;
            wr(lst = Query(l, r)), pc('\n');
        }
    }
    return 0;
}

::::

鸣谢