题解:P14312 【模板】K-D Tree

· · 题解

看到 K-D Tree 有模板题了,来水一发题解。

给了二进制分组的复杂度证明,OI-Wiki 上和洛谷上好像都没有给出过具体的证明(一个式子本蒟蒻看不懂,给一个详细一点的证明)。

K-D Tree

K-D Tree 是什么:

K-D Tree(KDT,K-Dimension Tree)是一种可以高效处理 k 维空间信息的数据结构。
在结点数 n 远大于 2^k 时,应用 K-D Tree 的时间效率很好。

简单来说,K-D Tree 就是一种可以高效处理高维空间中点的数据结构(例如可以解决强制在线的三维偏序),一般比较实用的是 2-D Tree 和 3-D Tree,也就是本题中要实现的。

节点信息

K-D Tree 是一颗二叉搜索树,每个节点是一个点,每棵子树是一个 k 维空间。

每个节点需要存的信息如下:

struct node{
    int x[3];
    int val,sum;
    int ls,rs;
    int l[3],r[3];
    int siz,tag;
}t[N],L,R;

数组 t 是节点,L,R 是查询时矩形的两个顶点。
每个节点中要存这个节点的每维位置,节点的权值,子树的权值和,左右儿子,子树所表示空间的边界,子树大小,修改的懒标记。

建树

下面以二维平面为例,给出 K-D Tree 的建树方法。

首先为了保证平衡,我们应当对每个维度轮流处理,以下面这个图为例:

先对于第一维找到中间的点 D,将平面分为两个部分。
换一个维度,找到两个部分的中点 C,E,将平面分为四个部分,然后以此类推。
最后的树应该长这样: 。

主要操作就是找到一个维度的中点,将点分为左右两部分,可以直接用 nth_element 实现,复杂度 O(n),所以总的建树复杂度为 O(n\log n),且最后的树高是严格 \log n+O(1) 的,代码如下:

int build(int l,int r,int k=0){
    if(l>r) return 0;
    int mid=(l+r)>>1;
    nth_element(a+l,a+mid,a+r+1,[k](int x,int y){
        return t[x].x[k]<t[y].x[k];
    });
    int p=a[mid];
    ls=build(l,mid-1,(k+1)%K);
    rs=build(mid+1,r,(k+1)%K);
    up(p);
    return p;
}

查询

写法非常简单,与查询部分无交直接返回,有部分相交递归子树,全部相交返回处理好的子树权值和,代码如下:

 int query(int p){
    if(!p) return 0;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return 0;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f) return t[p].sum;
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    down(p);
    return f*t[p].val+query(ls)+query(rs);
}

实现是简单的,主要讲一下复杂度和证明,以二维为例:

在一个节点,如果递归了左右儿子,那么说明矩形与该节点的矩形部分相交,再考虑这个节点的四个孙子有哪些会再次与矩形部分相交,注意到与矩形部分相交的节点的矩形一定会被矩形的一条边穿过,所以我们将查询矩形的四条边分开来考虑,而一条边(与坐标轴平行)最多穿过这个节点四个孙子的其中两个,这是显然的,所以可以得到:

T(n)=2T(\frac{n}{4})+O(1) T(n)=O(\sqrt n)

容易扩展到 k 维形式:

T(n)=2^{k-1}T(\frac{n}{2^k})+O(1) T(n)=O(n^{1-\frac{1}{k}})

所以查询的复杂度为 O(n^{1-\frac{1}{k}})

修改

和查询几乎同理,代码也差不多,就不讲了,代码如下:

void update(int p){
    if(!p) return;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f){
        t[p].tag+=c;
        t[p].sum+=t[p].siz*c;
        t[p].val+=c;
        return;
    }
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    if(f) t[p].val+=c;
    down(p);
    update(ls);
    update(rs);
    up(p);
}

其中 up 为和并节点信息的函数,down 为下传懒标记的函数,都是比较好理解的,就不多讲了,有一个实现的细节,在没有左右儿子的时候为了防止特判,可以对 0 号节点初始化一下:

void up(int p){
    t[p].sum=t[p].val+t[ls].sum+t[rs].sum;
    t[p].siz=1+t[ls].siz+t[rs].siz;
    for(int k=0;k<K;k++){
        t[p].l[k]=min(t[p].x[k],min(t[ls].l[k],t[rs].l[k]));
        t[p].r[k]=max(t[p].x[k],max(t[ls].r[k],t[rs].r[k]));
    }
}
void down(int p){
    if(!t[p].tag) return;
    int x=t[p].tag;
    if(ls) t[ls].tag+=x,t[ls].sum+=x*t[ls].siz,t[ls].val+=x;
    if(rs) t[rs].tag+=x,t[rs].sum+=x*t[rs].siz,t[rs].val+=x;
    t[p].tag=0;
}

插入/删除

删除比较显然,直接标记一下表示删了即可,这题用不到,主要讲一下怎么插入一个点。

首先直接插入显然不对,因为查询操作要保证子树大小严格减半。

比较常用维护方式是替罪羊树维护,根号重构以及二进制分组。

前两者的复杂度可以参考文末的文章,这里不多讲了。

比较推荐的方式是写二进制分组,容易实现复杂度也较优,为 O(n\sqrt n+n\log^2 n)

具体方式是这样的,开 \log n 棵树,大小分别为 2^0,2^1,2^2\cdots,当出现两颗大小相同的树时,合并为一棵新的大小为原来两倍的树,就是二进制的原理,复杂度也比较显然,建树次数是 O(\log n) 的,所以总复杂度为 O(n\log^2 n)

查询和修改时对于每棵树分别查询和修改即可,但这样复杂度为啥是对的?

写出复杂度:

\sum_{i=1}^{\log n}\sqrt{2^i}=\sum_{i=1}^{\log n}2^{\frac{i}{2}}=\sum_{i=1}^{\log n}\sqrt2^i

发现是个等比数列,直接求和:

\sqrt2\times\frac{\sqrt2^{\log n}-1}{\sqrt2-1}=O(\sqrt n)

所以对于每棵子树分别修改查询复杂度不会多 O(\log n),还是 O(\sqrt n) 的,写法还是比较简单的,插入节点代码如下:

a[n=1]=cnt;
for(int i=0;i<LG;i++)
    if(rt[i]) release(rt[i]);
    else{
        rt[i]=build(1,n);
        break;
    }

其中 release 函数为回收节点,这个是简单的:

void release(int &p){
    if(!p) return;
    a[++n]=p;
    down(p);
    release(ls);
    release(rs);
    p=0;
}

代码

给出我丑陋的实现,因为不想写两颗 K-D Tree,用了循环枚举维度,跑的挺慢,实际上可以展开写。

除此之外,K-D Tree 还可以写成线段树的形式,常数会大一点,但是好写不少,可以参考 ERoRaIn大佬的实现。

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1.5e5+10,LG=__lg(N)+3,inf=1e18;
int m,K;
struct node{
    int x[3];
    int val,sum;
    int ls,rs;
    int l[3],r[3];
    int siz,tag;
}t[N],L,R;
int rt[LG],cnt=0,n=0,a[N],c;
#define ls t[p].ls
#define rs t[p].rs
void up(int p){
    t[p].sum=t[p].val+t[ls].sum+t[rs].sum;
    t[p].siz=1+t[ls].siz+t[rs].siz;
    for(int k=0;k<K;k++){
        t[p].l[k]=min(t[p].x[k],min(t[ls].l[k],t[rs].l[k]));
        t[p].r[k]=max(t[p].x[k],max(t[ls].r[k],t[rs].r[k]));
    }
}
void down(int p){
    if(!t[p].tag) return;
    int x=t[p].tag;
    if(ls) t[ls].tag+=x,t[ls].sum+=x*t[ls].siz,t[ls].val+=x;
    if(rs) t[rs].tag+=x,t[rs].sum+=x*t[rs].siz,t[rs].val+=x;
    t[p].tag=0;
}
void release(int &p){
    if(!p) return;
    a[++n]=p;
    down(p);
    release(ls);
    release(rs);
    p=0;
}
int build(int l,int r,int k=0){
    if(l>r) return 0;
    int mid=(l+r)>>1;
    nth_element(a+l,a+mid,a+r+1,[k](int x,int y){
        return t[x].x[k]<t[y].x[k];
    });
    int p=a[mid];
    ls=build(l,mid-1,(k+1)%K);
    rs=build(mid+1,r,(k+1)%K);
    up(p);
    return p;
}
int query(int p){
    if(!p) return 0;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return 0;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f) return t[p].sum;
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    down(p);
    return f*t[p].val+query(ls)+query(rs);
}
void update(int p){
    if(!p) return;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f){
        t[p].tag+=c;
        t[p].sum+=t[p].siz*c;
        t[p].val+=c;
        return;
    }
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    if(f) t[p].val+=c;
    down(p);
    update(ls);
    update(rs);
    up(p);
}
#undef ls
#undef rs
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin>>K>>m;
    t[0]={0,0,0,0,0,0,0,inf,inf,inf,0,0,0,0,0};
    for(int op,ans=0;m--;){
        cin>>op;
        if(op==1){
            cnt++;
            for(int i=0;i<K;i++) cin>>t[cnt].x[i],t[cnt].x[i]^=ans;
            cin>>t[cnt].val;t[cnt].val^=ans;
            a[n=1]=cnt;
            for(int i=0;i<LG;i++)
                if(rt[i]) release(rt[i]);
                else{
                    rt[i]=build(1,n);
                    break;
                }
        }
        if(op==2){
            for(int i=0;i<K;i++) cin>>L.x[i],L.x[i]^=ans;
            for(int i=0;i<K;i++) cin>>R.x[i],R.x[i]^=ans;
            cin>>c;c^=ans;
            for(int i=0;i<LG;i++) update(rt[i]);
        }
        if(op==3){
            for(int i=0;i<K;i++) cin>>L.x[i],L.x[i]^=ans;
            for(int i=0;i<K;i++) cin>>R.x[i],R.x[i]^=ans;
            ans=0;
            for(int i=0;i<LG;i++) ans+=query(rt[i]);
            cout<<ans<<"\n";
        }
    }
    return 0;
}

参考资料

OI-Wiki,

线段树式写法(From ERoRaIn),

替罪羊树维护以及根号重构复杂度证明(From command_block)。