K-D Tree 学习笔记
本文参考 OI-wiki。
Part0. 简述
K-D Tree,全名 K-Dimension Tree,译为 K 维树,也就是维护 K 个维度的信息的树。比如如果要维护二维平面信息,可以把 K-D Tree 称作 2-D Tree。
Part1. 基础操作
建树
KDT 上的每一个节点都表示 K 维空间中的一个点。且 KDT 具有类似二叉搜索树的性质,即设定一个阈值,小于该阈值的放在左子树,否则放在右子树。
现在考虑如果有若干 K 维空间中的点,应当对其如何构建 KDT。
- 选择一个维度,并选择一个中间点,设这个中间点在该维度上的值为
B 。 - 将当前点集中该维度
\leq B 的点放入左子树,剩余放入右子树,以选择的中间点为子树的根。 - 以
K=2 为例,交替选择维度0,1 ,如图所示: - 那么可以建出以下树:
- 但是我们想要复杂度尽量小,那么每次选择的中间点和维度就不能乱选。一种尽量平衡树高的方法是交替选择维度
0,1,\dots,K ,每次选择该维度上的中位数。选出[l,r] 中的中位数mid 的方法可以用一个系统函数nth_element(a+l,a+mid,a+r+1,cmp)。
于是可以写出以下代码:
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
now=k,nth_element(a+l,a+mid,a+r+1,cmp);
cur=a[mid],dim(cur)=k;
bd(ls(cur),l,mid-1,(k+1)%K),bd(rs(cur),mid+1,r,(k+1)%K),ps(cur);
return;
}
注:ps(cur) 就是 pushup,用于上传信息,由于每个题的 pushup 都不一样所以在此不提。
数点
不妨以
- 其子树内值域与
R 无交,则直接停止递归,返回。 - 其子树内值域被
R 包含,直接返回子树答案。 - 其余情况,则先判断子树根是否被
R 包含,然后递归左右儿子。
不难写出如下代码:
int qy(int cur,int l1,int r1,int l2,int r2){
if(!cur||r1<l(cur)||l1>r(cur)||r2<u(cur)||l2>d(cur))return 0;
if(l(cur)>=l1&&r1>=r(cur)&&u(cur)>=l2&&r2>=d(cur))return sum(cur);
int res=0;
if(tr[cur].x>=l1&&r1>=tr[cur].x&&tr[cur].y>=l2&&r2>=tr[cur].y)res+=tr[cur].v;
return res+qy(ls(cur),l1,r1,l2,r2)+qy(rs(cur),l1,r1,l2,r2);
}
其中 l(cur),r(cur) 指 u(cur),d(cur) 同理。
复杂度被证明是
插入
显然不会只有查询没有操作。考虑插入一个新点,从根开始向下找,类似检索二叉搜索树一样根据中间值和区分维度选择是进入左子树还是右子树,最后成为某个新的叶子节点。
但是这样多了就会出现一个问题,如果反复在某一左或右子树插入节点,那么我们建树时精心构造的平衡树高就会被破坏,从而导致复杂度失衡。这时我们需要重构子树以平衡复杂度。
这里有一个比较好写好用的做法叫做替罪羊式重构。具体来说,设定一个常数
代码:
bool chk(int cur){
if(max(siz(ls(cur)),siz(rs(cur)))>=1.0*alpha*siz(cur))return 1;
return 0;
}
void ret(int cur){if(!cur)return;ret(ls(cur)),aa[++n]=cur,ret(rs(cur));}
void rebd(int &cur){n=0,ret(cur),bd(cur,1,n,0);}
void ins(int &cur,int x){
if(!cur)return cur=x,ps(cur),void();
if(!dim(cur)){
if(tr[x].x<=tr[cur].x)ins(ls(cur),x);
else ins(rs(cur),x);
}
else{
if(tr[x].y<=tr[cur].y)ins(ls(cur),x);
else ins(rs(cur),x);
}
ps(cur);
if(chk(cur))rebd(cur);
return;
}
注:其中 ret 函数是将树拍平成序列操作。不难发现这是中序遍历节点并记录,这样做可以有效保留树上信息至序列上,也方便重构。
:::warning[关于替罪羊树重构]{open} 这个东西的复杂度似乎无法保证而且很玄学,主要是好写。
更优秀的写法应当是二进制分组或是根号重构。 :::
:::info[KD-Tree 的优势?]{open} 这种复杂度较高且玄学的数据结构到底有何优势?
- 需要空间小,仅仅需要的是线性空间。优于树套树等需要大空间的数据结构。
- 能在线,如果遇到强制在线的题目 cdq 或者是 莫队 就无用武之地了,但是 KDT 就能够使用。 :::
根据上述描述,有一道给 KDT 量身定做的例题:
eg.1 P4148 简单题
两个操作:插入单点,二维数点,强制在线,空间只有 20MB。
不难发现就是裸的模板,直接将上述代码整合一下即可。
:::success[代码]
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5;
const double alpha=0.75;
int n,rt,tot,lst,aa[N];
struct kdt{
int ls,rs,l,r,u,d,siz,sum;
int x,y,v,dim;
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define l(x) tr[x].l
#define r(x) tr[x].r
#define u(x) tr[x].u
#define d(x) tr[x].d
#define siz(x) tr[x].siz
#define sum(x) tr[x].sum
#define dim(x) tr[x].dim
}tr[N];
void chkmin(int &x,int y){if(x>y)x=y;return;}
void chkmax(int &x,int y){if(x<y)x=y;return;}
void ps(int cur){
siz(cur)=siz(ls(cur))+siz(rs(cur))+1;
sum(cur)=sum(ls(cur))+sum(rs(cur))+tr[cur].v;
l(cur)=r(cur)=tr[cur].x,u(cur)=d(cur)=tr[cur].y;
if(ls(cur))
chkmin(l(cur),l(ls(cur))),chkmax(r(cur),r(ls(cur))),
chkmin(u(cur),u(ls(cur))),chkmax(d(cur),d(ls(cur)));
if(rs(cur))
chkmin(l(cur),l(rs(cur))),chkmax(r(cur),r(rs(cur))),
chkmin(u(cur),u(rs(cur))),chkmax(d(cur),d(rs(cur)));
return;
}
bool chk(int cur){
if(max(siz(ls(cur)),siz(rs(cur)))>=1.0*alpha*siz(cur))return 1;
return 0;
}
bool cmpx(int x,int y){return tr[x].x<tr[y].x;}
bool cmpy(int x,int y){return tr[x].y<tr[y].y;}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
if(!k)nth_element(aa+l,aa+mid,aa+r+1,cmpx);
else nth_element(aa+l,aa+mid,aa+r+1,cmpy);
cur=aa[mid],dim(cur)=k;
bd(ls(cur),l,mid-1,k^1),bd(rs(cur),mid+1,r,k^1),ps(cur);
return;
}
void ret(int cur){if(!cur)return;ret(ls(cur)),aa[++n]=cur,ret(rs(cur));}
void rebd(int &cur){n=0,ret(cur),bd(cur,1,n,0);}
void ins(int &cur,int x){
if(!cur)return cur=x,ps(cur),void();
if(!dim(cur)){
if(tr[x].x<=tr[cur].x)ins(ls(cur),x);
else ins(rs(cur),x);
}
else{
if(tr[x].y<=tr[cur].y)ins(ls(cur),x);
else ins(rs(cur),x);
}
ps(cur);
if(chk(cur))rebd(cur);
return;
}
int qy(int cur,int l1,int r1,int l2,int r2){
if(!cur||r1<l(cur)||l1>r(cur)||r2<u(cur)||l2>d(cur))return 0;
if(l(cur)>=l1&&r1>=r(cur)&&u(cur)>=l2&&r2>=d(cur))return sum(cur);
int res=0;
if(tr[cur].x>=l1&&r1>=tr[cur].x&&tr[cur].y>=l2&&r2>=tr[cur].y)res+=tr[cur].v;
return res+qy(ls(cur),l1,r1,l2,r2)+qy(rs(cur),l1,r1,l2,r2);
}
int main(){
int qwq;cin>>qwq;
while(1){
int op,a,b,c,d;
cin>>op;
if(op==3)break;
cin>>a>>b>>c,a^=lst,b^=lst,c^=lst;
if(op==1)tot++,tr[tot].x=a,tr[tot].y=b,tr[tot].v=c,ins(rt,tot);
else cin>>d,d^=lst,lst=qy(rt,a,c,b,d),cout<<lst<<"\n";
}
return 0;
}
:::
:::info[本题的双倍经验]{open} P4390 [BalkanOI 2007] Mokia 摩基亚 :::
Part2. 其他操作
引入区间操作
也就是从单点操作区间查变为了区间操作区间查。与线段树类似,我们同样也要引入懒标记,由于 K-D Tree 的二叉树结构和线段树存在类似之处,所以我们可以比较方便的维护懒标记。
唯一一点需要注意的是,当我们进行替罪羊树重构的时候,所有父子关系可能会被打乱,那么此时应当将重构子树的所有标记先全部下传,再重构。其余似乎和线段树懒标记大同小异。
eg.2 P14312 【模板】K-D Tree
似乎就是加入区间加的 eg.1。
写代码的整体思路如上所示,主要需要注意的是一些细节上的处理。比如 pushdown 的时机,是否需要 pushup 等等。
比如,在 insert 一个新的点时应当将其经过的所有点都 pushdown 懒标记,如果不 pushdown 的话等该新点加入子树后就能加入该懒标记的贡献,此时贡献就错误了。
代码中这样具体的细节还挺多的。
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5;
const double alpha=0.75;
int K,Q,n,rt,tot,lst,a[N],now;
struct kdt{
int ls,rs,siz,sum;
int x[3],v,dim,tg;
int mx[3],mn[3];
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define siz(x) tr[x].siz
#define sum(x) tr[x].sum
#define dim(x) tr[x].dim
#define tg(x) tr[x].tg
}tr[N];
void chkmin(int &x,int y){if(x>y)x=y;return;}
void chkmax(int &x,int y){if(x<y)x=y;return;}
void ps(int cur){
siz(cur)=siz(ls(cur))+siz(rs(cur))+1;
sum(cur)=sum(ls(cur))+sum(rs(cur))+tr[cur].v;
for(int i=0;i<K;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].x[i];
if(ls(cur))
chkmax(tr[cur].mx[i],tr[ls(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[ls(cur)].mn[i]);
if(rs(cur))
chkmax(tr[cur].mx[i],tr[rs(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[rs(cur)].mn[i]);
}
return;
}
bool chk(int cur){
if(max(siz(ls(cur)),siz(rs(cur)))>=1.0*alpha*siz(cur))return 1;
return 0;
}
bool cmp(int x,int y){return tr[x].x[now]<tr[y].x[now];}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
now=k,nth_element(a+l,a+mid,a+r+1,cmp);
cur=a[mid],dim(cur)=k;
bd(ls(cur),l,mid-1,(k+1)%K),bd(rs(cur),mid+1,r,(k+1)%K),ps(cur);
return;
}
void ad(int cur,int val){tg(cur)+=val,sum(cur)+=val*siz(cur),tr[cur].v+=val;}
void alpd(int cur){
if(ls(cur))ad(ls(cur),tg(cur));
if(rs(cur))ad(rs(cur),tg(cur));
tg(cur)=0;
if(ls(cur))alpd(ls(cur));
if(rs(cur))alpd(rs(cur));
return;
}
void pd(int cur){
if(!tg(cur))return;
if(ls(cur))ad(ls(cur),tg(cur));
if(rs(cur))ad(rs(cur),tg(cur));
return tg(cur)=0,void();
}
void ret(int cur){if(!cur)return;ret(ls(cur)),a[++n]=cur,ret(rs(cur));}
void rebd(int &cur){n=0,alpd(cur),ret(cur),bd(cur,1,n,0);}
void ins(int &cur,int x){
if(!cur)return cur=x,ps(cur),void();
pd(cur);
if(tr[x].x[dim(cur)]<=tr[cur].x[dim(cur)])ins(ls(cur),x);
else ins(rs(cur),x);
ps(cur);
if(chk(cur))rebd(cur);
return;
}
int qyk2(int cur,int l1,int r1,int l2,int r2){
if(!cur||r1<tr[cur].mn[0]||l1>tr[cur].mx[0]||r2<tr[cur].mn[1]||l2>tr[cur].mx[1])return 0;
if(l1<=tr[cur].mn[0]&&tr[cur].mx[0]<=r1&&l2<=tr[cur].mn[1]&&tr[cur].mx[1]<=r2)return sum(cur);
int res=0;
if(l1<=tr[cur].x[0]&&tr[cur].x[0]<=r1&&l2<=tr[cur].x[1]&&tr[cur].x[1]<=r2)res=tr[cur].v;
return pd(cur),res+qyk2(ls(cur),l1,r1,l2,r2)+qyk2(rs(cur),l1,r1,l2,r2);
}
void updk2(int cur,int l1,int r1,int l2,int r2,int val){
if(!cur||r1<tr[cur].mn[0]||l1>tr[cur].mx[0]||r2<tr[cur].mn[1]||l2>tr[cur].mx[1])return;
if(l1<=tr[cur].mn[0]&&tr[cur].mx[0]<=r1&&l2<=tr[cur].mn[1]&&tr[cur].mx[1]<=r2)
return ad(cur,val),void();
if(l1<=tr[cur].x[0]&&tr[cur].x[0]<=r1&&l2<=tr[cur].x[1]&&tr[cur].x[1]<=r2)tr[cur].v+=val;
pd(cur),updk2(ls(cur),l1,r1,l2,r2,val),updk2(rs(cur),l1,r1,l2,r2,val),ps(cur);
return;
}
int qyk3(int cur,int l1,int r1,int l2,int r2,int l3,int r3){
if(!cur||r1<tr[cur].mn[0]||l1>tr[cur].mx[0]||r2<tr[cur].mn[1]||l2>tr[cur].mx[1]||r3<tr[cur].mn[2]||l3>tr[cur].mx[2])return 0;
if(l1<=tr[cur].mn[0]&&tr[cur].mx[0]<=r1&&l2<=tr[cur].mn[1]&&tr[cur].mx[1]<=r2&&l3<=tr[cur].mn[2]&&tr[cur].mx[2]<=r3)return sum(cur);
int res=0;
if(l1<=tr[cur].x[0]&&tr[cur].x[0]<=r1&&l2<=tr[cur].x[1]&&tr[cur].x[1]<=r2&&l3<=tr[cur].x[2]&&tr[cur].x[2]<=r3)res=tr[cur].v;
return pd(cur),res+qyk3(ls(cur),l1,r1,l2,r2,l3,r3)+qyk3(rs(cur),l1,r1,l2,r2,l3,r3);
}
void updk3(int cur,int l1,int r1,int l2,int r2,int l3,int r3,int val){
if(!cur||r1<tr[cur].mn[0]||l1>tr[cur].mx[0]||r2<tr[cur].mn[1]||l2>tr[cur].mx[1]||r3<tr[cur].mn[2]||l3>tr[cur].mx[2])return;
if(l1<=tr[cur].mn[0]&&tr[cur].mx[0]<=r1&&l2<=tr[cur].mn[1]&&tr[cur].mx[1]<=r2&&l3<=tr[cur].mn[2]&&tr[cur].mx[2]<=r3)
return ad(cur,val),void();
if(l1<=tr[cur].x[0]&&tr[cur].x[0]<=r1&&l2<=tr[cur].x[1]&&tr[cur].x[1]<=r2&&l3<=tr[cur].x[2]&&tr[cur].x[2]<=r3)tr[cur].v+=val;
pd(cur),updk3(ls(cur),l1,r1,l2,r2,l3,r3,val),updk3(rs(cur),l1,r1,l2,r2,l3,r3,val),ps(cur);
return;
}
signed main(){
// freopen("qwq.in","r",stdin);
// freopen("1.out","w",stdout);
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>K>>Q;
while(Q--){
int op;cin>>op;
if(op==1){
int v;tot++;
for(int i=1,x;i<=K;i++)cin>>x,x^=lst,tr[tot].x[i-1]=x;
cin>>v,v^=lst,tr[tot].v=v,ins(rt,tot);
}
else if(op==2){
int b[3],c[3],v;
for(int i=0;i<K;i++)cin>>b[i],b[i]^=lst;
for(int i=0;i<K;i++)cin>>c[i],c[i]^=lst;
cin>>v,v^=lst;
if(K==2)updk2(rt,b[0],c[0],b[1],c[1],v);
else updk3(rt,b[0],c[0],b[1],c[1],b[2],c[2],v);
}
else{
int b[3],c[3];
for(int i=0;i<K;i++)cin>>b[i],b[i]^=lst;
for(int i=0;i<K;i++)cin>>c[i],c[i]^=lst;
if(K==2)lst=qyk2(rt,b[0],c[0],b[1],c[1]),cout<<lst<<"\n";
else lst=qyk3(rt,b[0],c[0],b[1],c[1],b[2],c[2]),cout<<lst<<"\n";
}
}
return 0;
}
似乎因为一些压行导致显示出来的效果有些诡异啊/kk。 :::
最近点对和估价函数
K-D Tree 另一个很常用的作用是求平面中关于一个点,离它最近的或是最远的点。
以一个题为例:
eg.3 P2479 [SDOI2010] 捉迷藏
考虑先建出 2-D Tree。枚举每个节点,计算离其最远的点和最近的点的距离(注意是曼哈顿距离),然后计算差值最小值。
那么首先考虑如何对于给定的点计算离其最远的点,那么我们需要的是快速确定我们应当检索当前的左子树还是右子树。
但是这并不好处理。而如果使用遍历整个左右子树得到信息的方式,复杂度显然是不优的。于是我们考虑设计估价函数:计算一棵子树内可能产生与查询点的最大距离。
具体的,我们可以维护一棵子树上两个维度的最大 / 最小值
int g1(int cur,int x){
return max(abs(tr[x].x[0]-tr[cur].mx[0]),abs(tr[x].x[0]-tr[cur].mn[0]))
+max(abs(tr[x].x[1]-tr[cur].mx[1]),abs(tr[x].x[1]-tr[cur].mn[1]));
}
当然,这只是理想中最大的距离,只有当有一个点同时满足横纵坐标都顶到了这个的上界才存在。所以,这只是估价函数,而非计算函数。
那么这有什么用呢?可以先简单计算出当前点的左右子树的估价函数值
代码如下:
void qy1(int cur,int x){
if(!cur)return;
if(x!=cur)chkmax(maxn,calc(tr[x],tr[cur]));
int ld=g1(ls(cur),x),rd=g1(rs(cur),x);
if(ld>rd){
if(ld>maxn)qy1(ls(cur),x);
if(rd>maxn)qy1(rs(cur),x);
}
else{
if(rd>maxn)qy1(rs(cur),x);
if(ld>maxn)qy1(ls(cur),x);
}
return;
}
设计最小距离估价函数是同理的。读者可尝试自行设计。
:::warning[关于时间复杂度]{open}
这种设计估价函数的方式本质上是一种最优化剪枝而非实质性的优化复杂度。也就是说,本做法的时间复杂度仍然是单次
所以这种做法更多的是用于在比赛中骗取更多的部分分或是在随机数据下进行求解。 :::
:::success[本题完整代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n,now,maxn,minn,ans=1e18;
struct node{
int l,r;
int mx[2],mn[2],x[2];
bool operator<(const node&X)const{return x[now]<X.x[now];}
}tr[N];
void chkmax(int &x,int y){if(x<y)x=y;return;}
void chkmin(int &x,int y){if(x>y)x=y;return;}
void ps(int cur){
for(int i=0;i<2;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].x[i];
if(tr[cur].l)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].l].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].l].mn[i]);
if(tr[cur].r)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].r].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].r].mn[i]);
}
return;
}
int bd(int l,int r,int op){
if(l>r)return 0;
int mid=l+r>>1;
now=op,nth_element(tr+l,tr+mid,tr+r+1);
tr[mid].l=bd(l,mid-1,op^1),tr[mid].r=bd(mid+1,r,op^1);
ps(mid);return mid;
}
int g1(int cur,int x){
return max(abs(tr[x].x[0]-tr[cur].mx[0]),abs(tr[x].x[0]-tr[cur].mn[0]))
+max(abs(tr[x].x[1]-tr[cur].mx[1]),abs(tr[x].x[1]-tr[cur].mn[1]));
}
int g2(int cur,int x){
int s0=0,s1=0;
if(tr[x].x[0]<tr[cur].mn[0])s0=tr[cur].mn[0]-tr[x].x[0];
if(tr[x].x[0]>tr[cur].mx[0])s0=tr[x].x[0]-tr[cur].mx[0];
if(tr[x].x[1]<tr[cur].mn[1])s1=tr[cur].mn[1]-tr[x].x[1];
if(tr[x].x[1]>tr[cur].mx[1])s1=tr[x].x[1]-tr[cur].mx[1];
return s0+s1;
}
int calc(node x,node y){return abs(x.x[0]-y.x[0])+abs(x.x[1]-y.x[1]);}
void qy1(int cur,int x){
if(!cur)return;
if(x!=cur)chkmax(maxn,calc(tr[x],tr[cur]));
int ld=g1(tr[cur].l,x),rd=g1(tr[cur].r,x);
if(ld>rd){
if(ld>maxn)qy1(tr[cur].l,x);
if(rd>maxn)qy1(tr[cur].r,x);
}
else{
if(rd>maxn)qy1(tr[cur].r,x);
if(ld>maxn)qy1(tr[cur].l,x);
}
return;
}
void qy2(int cur,int x){
if(!cur)return;
if(x!=cur)chkmin(minn,calc(tr[x],tr[cur]));
int ld=g2(tr[cur].l,x),rd=g2(tr[cur].r,x);
if(ld<rd){
if(ld<minn)qy2(tr[cur].l,x);
if(rd<minn)qy2(tr[cur].r,x);
}
else{
if(rd<minn)qy2(tr[cur].r,x);
if(ld<minn)qy2(tr[cur].l,x);
}
return;
}
signed main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>tr[i].x[0]>>tr[i].x[1];
int rt=bd(1,n,0);
for(int i=1;i<=n;i++){
maxn=0,minn=1e18;
qy1(rt,i),qy2(rt,i);
ans=min(ans,maxn-minn);
}
cout<<ans;
return 0;
}
:::
Part3. 一些练习
eg.4 P6224 [BJWC2014] 数据
求解最近最远点对加入了单点插入。其实就是 eg.1 与 eg.3 的缝合,个人感觉没什么好说的。
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5,inf=1e18;
const double alpha=0.75;
int n,q,rt,tot,now,m,a[N],maxn,minn;
struct node{
int ls,rs,d,siz;
int mx[2],mn[2],x[2];
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define siz(x) tr[x].siz
#define d(x) tr[x].d
}tr[N];
void chkmax(int &x,int y){if(x<y)x=y;return;}
void chkmin(int &x,int y){if(x>y)x=y;return;}
bool cmp(int x,int y){return tr[x].x[now]<tr[y].x[now];}
bool chk(int cur){
if(max(siz(ls(cur)),siz(rs(cur)))>=1.0*alpha*siz(cur))return 1;
return 0;
}
void ps(int cur){
siz(cur)=siz(ls(cur))+siz(rs(cur))+1;
for(int i=0;i<2;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].x[i];
if(ls(cur))
tr[cur].mx[i]=max(tr[cur].mx[i],tr[ls(cur)].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[ls(cur)].mn[i]);
if(rs(cur))
tr[cur].mx[i]=max(tr[cur].mx[i],tr[rs(cur)].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[rs(cur)].mn[i]);
}
return;
}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
now=k,nth_element(a+l,a+mid,a+r+1,cmp);
cur=a[mid],d(cur)=k;
bd(ls(cur),l,mid-1,k^1),bd(rs(cur),mid+1,r,k^1),ps(cur);
return;
}
void ret(int cur){if(!cur)return;ret(ls(cur)),a[++n]=cur,ret(rs(cur));}
void rebd(int &cur){n=0,ret(cur),bd(cur,1,n,0);}
void ins(int &cur,int x){
if(!cur)return cur=x,ps(cur),void();
if(tr[x].x[d(cur)]<=tr[cur].x[d(cur)])ins(ls(cur),x);
else ins(rs(cur),x);
ps(cur);
if(chk(cur))rebd(cur);
return;
}
int g1(int cur,int x){
return max(abs(tr[x].x[0]-tr[cur].mx[0]),abs(tr[x].x[0]-tr[cur].mn[0]))
+max(abs(tr[x].x[1]-tr[cur].mx[1]),abs(tr[x].x[1]-tr[cur].mn[1]));
}
int g2(int cur,int x){
int s0=0,s1=0;
if(tr[x].x[0]<tr[cur].mn[0])s0=tr[cur].mn[0]-tr[x].x[0];
if(tr[x].x[0]>tr[cur].mx[0])s0=tr[x].x[0]-tr[cur].mx[0];
if(tr[x].x[1]<tr[cur].mn[1])s1=tr[cur].mn[1]-tr[x].x[1];
if(tr[x].x[1]>tr[cur].mx[1])s1=tr[x].x[1]-tr[cur].mx[1];
return s0+s1;
}
int calc(node x,node y){return abs(x.x[0]-y.x[0])+abs(x.x[1]-y.x[1]);}
void qy1(int cur,int x){
if(!cur)return;
chkmax(maxn,calc(tr[x],tr[cur]));
int ld=g1(ls(cur),x),rd=g1(rs(cur),x);
if(ld>rd){
if(ld>maxn)qy1(ls(cur),x);
if(rd>maxn)qy1(rs(cur),x);
}
else{
if(rd>maxn)qy1(rs(cur),x);
if(ld>maxn)qy1(ls(cur),x);
}
return;
}
void qy2(int cur,int x){
if(!cur)return;
chkmin(minn,calc(tr[x],tr[cur]));
int ld=g2(ls(cur),x),rd=g2(rs(cur),x);
if(ld<rd){
if(ld<minn)qy2(ls(cur),x);
if(rd<minn)qy2(rs(cur),x);
}
else{
if(rd<minn)qy2(rs(cur),x);
if(ld<minn)qy2(ls(cur),x);
}
return;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n,tot=n;
for(int i=1;i<=n;i++)cin>>tr[i].x[0]>>tr[i].x[1],a[i]=i;
bd(rt,1,n,0),cin>>q;
while(q--){
int op,l,r;
cin>>op>>l>>r,tr[++tot].x[0]=l,tr[tot].x[1]=r;
if(!op)ins(rt,tot);
else if(op==1)minn=inf,qy2(rt,tot),cout<<minn<<"\n";
else maxn=-inf,qy1(rt,tot),cout<<maxn<<"\n";
}
return 0;
}
:::
:::info[本题的双倍经验]{open} P4169 [Violet] 天使玩偶/SJY摆棋子
:::
eg.5 P6514 [QkOI#R1] Quark and Strings
简单变形。题目事实上可以理解为插入一些线段,询问一段区间被多少段线段完全包含。
那么将线段
作为练习,我们使用 KDT 维护这个东西。当然,存在其他码量更小,复杂度更优的数据结构维护,在此不提。
:::success[代码]
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5;
const double alpha=0.75;
int n,m,q,rt,tot,aa[N];
struct kdt{
int ls,rs,l,r,u,d,siz,sum;
int x,y,v,dim;
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define l(x) tr[x].l
#define r(x) tr[x].r
#define u(x) tr[x].u
#define d(x) tr[x].d
#define siz(x) tr[x].siz
#define sum(x) tr[x].sum
#define dim(x) tr[x].dim
}tr[N];
void chkmin(int &x,int y){if(x>y)x=y;return;}
void chkmax(int &x,int y){if(x<y)x=y;return;}
void ps(int cur){
siz(cur)=siz(ls(cur))+siz(rs(cur))+1;
sum(cur)=sum(ls(cur))+sum(rs(cur))+tr[cur].v;
l(cur)=r(cur)=tr[cur].x,u(cur)=d(cur)=tr[cur].y;
if(ls(cur))
chkmin(l(cur),l(ls(cur))),chkmax(r(cur),r(ls(cur))),
chkmin(u(cur),u(ls(cur))),chkmax(d(cur),d(ls(cur)));
if(rs(cur))
chkmin(l(cur),l(rs(cur))),chkmax(r(cur),r(rs(cur))),
chkmin(u(cur),u(rs(cur))),chkmax(d(cur),d(rs(cur)));
return;
}
bool chk(int cur){
if(max(siz(ls(cur)),siz(rs(cur)))>=1.0*alpha*siz(cur))return 1;
return 0;
}
bool cmpx(int x,int y){return tr[x].x<tr[y].x;}
bool cmpy(int x,int y){return tr[x].y<tr[y].y;}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
if(!k)nth_element(aa+l,aa+mid,aa+r+1,cmpx);
else nth_element(aa+l,aa+mid,aa+r+1,cmpy);
cur=aa[mid],dim(cur)=k;
bd(ls(cur),l,mid-1,k^1),bd(rs(cur),mid+1,r,k^1),ps(cur);
return;
}
void ret(int cur){if(!cur)return;ret(ls(cur)),aa[++n]=cur,ret(rs(cur));}
void rebd(int &cur){n=0,ret(cur),bd(cur,1,n,0);}
void ins(int &cur,int x){
if(!cur)return cur=x,ps(cur),void();
if(!dim(cur)){
if(tr[x].x<=tr[cur].x)ins(ls(cur),x);
else ins(rs(cur),x);
}
else{
if(tr[x].y<=tr[cur].y)ins(ls(cur),x);
else ins(rs(cur),x);
}
ps(cur);
if(chk(cur))rebd(cur);
return;
}
int qy(int cur,int l1,int r1,int l2,int r2){
if(!cur||r1<l(cur)||l1>r(cur)||r2<u(cur)||l2>d(cur))return 0;
if(l(cur)>=l1&&r1>=r(cur)&&u(cur)>=l2&&r2>=d(cur))return sum(cur);
int res=0;
if(tr[cur].x>=l1&&r1>=tr[cur].x&&tr[cur].y>=l2&&r2>=tr[cur].y)res+=tr[cur].v;
return res+qy(ls(cur),l1,r1,l2,r2)+qy(rs(cur),l1,r1,l2,r2);
}
int main(){
cin>>m>>q;
while(q--){
int op,l,r;
cin>>op>>l>>r;
if(op==1)tot++,tr[tot].x=l,tr[tot].y=r,tr[tot].v=1,ins(rt,tot);
else cout<<qy(rt,1,l,r,m)<<"\n";
}
return 0;
}
:::
eg.6 P4475 巧克力王国
自主设计一个简单的估价函数。
设牛奶和可可分别为两个维度,建立 2-D Tree。
同样维护一棵子树内两个维度的最大 / 最小值
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e4+5;
int n,m,now,ans;
struct node{
int l,r,h,sum;
int v[2],mx[2],mn[2];
bool operator<(const node&X)const{return v[now]<X.v[now];}
}tr[N];
struct query{int a,b,c;}q[N];
void ps(int cur){
for(int i=0;i<2;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].v[i];
if(tr[cur].l)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].l].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].l].mn[i]);
if(tr[cur].r)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].r].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].r].mn[i]);
}
tr[cur].sum+=tr[tr[cur].l].sum+tr[tr[cur].r].sum;
return;
}
int bd(int l,int r,int op){
if(l>r)return 0;
int mid=l+r>>1;
now=op,nth_element(tr+l,tr+mid,tr+r+1),tr[mid].sum=tr[mid].h;
tr[mid].l=bd(l,mid-1,op^1),tr[mid].r=bd(mid+1,r,op^1),ps(mid);
return mid;
}
int calc(int cur,int x,int y){return x*q[cur].a+y*q[cur].b<q[cur].c;}
void qy(int cur,int x){
if(!x)return;
int tmp=calc(cur,tr[x].mx[0],tr[x].mx[1])+calc(cur,tr[x].mn[0],tr[x].mx[1])+
calc(cur,tr[x].mx[0],tr[x].mn[1])+calc(cur,tr[x].mn[0],tr[x].mn[1]);
if(tmp==4){ans+=tr[x].sum;return;}
else if(!tmp)return;
if(calc(cur,tr[x].v[0],tr[x].v[1]))ans+=tr[x].h;
qy(cur,tr[x].l),qy(cur,tr[x].r);
return;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>tr[i].v[0]>>tr[i].v[1]>>tr[i].h;
int rt=bd(1,n,0);
for(int i=1;i<=m;i++)cin>>q[i].a>>q[i].b>>q[i].c,qy(i,rt),cout<<ans<<"\n",ans=0;
return 0;
}
:::
eg.7 P2093 [国家集训队] JZPFAR
最远点对变成了 K 远点对。
估价函数大致和最远点对没区别,只是曼哈顿距离变成了欧几里得距离。对于 K 远点对,考虑设置一个大小为
- 如果堆大小
<K ,直接加入当前点。 - 否则:
- 如果当前点比堆顶优,弹出堆顶,加入当前点。
- 如果当前点子树可能取到的最大值也就是估价函数都没有堆顶优,那么当前点子树内都是不优的点,直接返回。
- 否则剩余情况继续向下递归,还是按照左右儿子的估价函数决定先递归哪一边以起到剪枝效果。
数据是随机的,那么复杂度是
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n,m,now,px,py,k;
struct node{
int l,r,id;
int mx[2],mn[2],v[2];
bool operator<(const node&X)const{return v[now]<X.v[now];}
}tr[N];
struct st{
int id,val;
bool operator<(const st&X)const{
if(val==X.val)return id<X.id;
return val>X.val;
}
};
priority_queue<st>q;
void ps(int cur){
for(int i=0;i<2;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].v[i];
if(tr[cur].l)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].l].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].l].mn[i]);
if(tr[cur].r)
tr[cur].mx[i]=max(tr[cur].mx[i],tr[tr[cur].r].mx[i]),
tr[cur].mn[i]=min(tr[cur].mn[i],tr[tr[cur].r].mn[i]);
}
return;
}
int bd(int l,int r,int op){
if(l>r)return 0;
int mid=l+r>>1;
now=op,nth_element(tr+l,tr+mid,tr+r+1);
tr[mid].l=bd(l,mid-1,op^1),tr[mid].r=bd(mid+1,r,op^1);
ps(mid);return mid;
}
int _dis(int x){return (tr[x].v[0]-px)*(tr[x].v[0]-px)+(tr[x].v[1]-py)*(tr[x].v[1]-py);}
int calc(int x){
return max({(tr[x].mx[0]-px)*(tr[x].mx[0]-px)+(tr[x].mx[1]-py)*(tr[x].mx[1]-py),
(tr[x].mx[0]-px)*(tr[x].mx[0]-px)+(tr[x].mn[1]-py)*(tr[x].mn[1]-py),
(tr[x].mn[0]-px)*(tr[x].mn[0]-px)+(tr[x].mx[1]-py)*(tr[x].mx[1]-py),
(tr[x].mn[0]-px)*(tr[x].mn[0]-px)+(tr[x].mn[1]-py)*(tr[x].mn[1]-py)});
}
void qy(int x){
if(!x)return;
if((int)q.size()<k)q.push({tr[x].id,_dis(x)});
else{
if(q.top().val<_dis(x)||q.top().val==_dis(x)&&q.top().id>tr[x].id)
q.pop(),q.push({tr[x].id,_dis(x)});
else if(calc(x)<q.top().val)return;
}
if(calc(tr[x].l)>calc(tr[x].r))qy(tr[x].l),qy(tr[x].r);
else qy(tr[x].r),qy(tr[x].l);
return;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++)cin>>tr[i].v[0]>>tr[i].v[1],tr[i].id=i;
int rt=bd(1,n,0);
cin>>m;
while(m--){
cin>>px>>py>>k;
while(!q.empty())q.pop();
qy(rt),cout<<q.top().id<<"\n";
// while(!q.empty())cout<<q.top().id<<"qwq\n",q.pop();
}
return 0;
}
:::
:::info[本题的双倍经验]{open} P4357 [CQOI2016] K 远点对
需要注意的是,这个题正解并不是 KDT 而是旋转卡壳(我也不会),由于并没有保证数据随机所以最坏复杂度是
eg.8 P3769 [CH弱省胡策R2] TATT
四维偏序板子题。
可以先排序排掉一维
具体地,设
复杂度是
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e4+5,inf=1e18;
int n,rt,K,now,ans,f[N],a[N];
struct qwq{int x[4];}w[N];
struct node{
int ls,rs,x[3];
int mx[3],mn[3],v,dat;
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define dat(x) tr[x].dat
}tr[N];
bool cmp(qwq x,qwq y){
if(x.x[0]!=y.x[0])return x.x[0]<y.x[0];
if(x.x[1]!=y.x[1])return x.x[1]<y.x[1];
if(x.x[2]!=y.x[2])return x.x[2]<y.x[2];
return x.x[3]<y.x[3];
}
bool cmpt(int x,int y){return tr[x].x[now]<tr[y].x[now];}
void chkmin(int &x,int y){if(x>y)x=y;return;}
void chkmax(int &x,int y){if(x<y)x=y;return;}
void ps(int cur){
dat(cur)=max(max(dat(ls(cur)),dat(rs(cur))),tr[cur].v);
for(int i=0;i<3;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].x[i];
if(ls(cur))
chkmax(tr[cur].mx[i],tr[ls(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[ls(cur)].mn[i]);
if(rs(cur))
chkmax(tr[cur].mx[i],tr[rs(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[rs(cur)].mn[i]);
}
return;
}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
now=k,nth_element(a+l,a+mid,a+r+1,cmpt),cur=a[mid];
bd(ls(cur),l,mid-1,(k+1)%K),bd(rs(cur),mid+1,r,(k+1)%K),ps(cur);
return;
}
void upd(int cur,int x,int y,int z,int val){
if(!cur)return;
if(tr[cur].mx[0]<x||tr[cur].mn[0]>x||tr[cur].mx[1]<y||tr[cur].mn[1]>y||tr[cur].mx[2]<z||tr[cur].mn[2]>z)return;
if(tr[cur].x[0]==x&&tr[cur].x[1]==y&&tr[cur].x[2]==z)
return chkmax(tr[cur].v,val),ps(cur),void();
upd(ls(cur),x,y,z,val),upd(rs(cur),x,y,z,val),ps(cur);
return;
}
int qy(int cur,int x,int y,int z){
if(!cur)return -inf;
if(tr[cur].mn[0]>x||tr[cur].mn[1]>y||tr[cur].mn[2]>z)return -inf;
if(tr[cur].mx[0]<=x&&tr[cur].mx[1]<=y&&tr[cur].mx[2]<=z)return dat(cur);
int res=-inf;
if(tr[cur].x[0]<=x&&tr[cur].x[1]<=y&&tr[cur].x[2]<=z)res=tr[cur].v;
return max(res,max(qy(ls(cur),x,y,z),qy(rs(cur),x,y,z)));
}
signed main(){
cin>>n,K=3;
for(int i=1;i<=n;i++)a[i]=i;
for(int i=1;i<=n;i++)
cin>>w[i].x[0]>>w[i].x[1]>>w[i].x[2]>>w[i].x[3];
sort(w+1,w+n+1,cmp);
for(int i=1;i<=n;i++)
for(int j=0;j<3;j++)tr[i].x[j]=w[i].x[j+1];
bd(rt,1,n,0);
for(int i=1;i<=n;i++){
int res=qy(rt,w[i].x[1],w[i].x[2],w[i].x[3]);
f[i]=max(f[i],max(0ll,res)+1);
upd(rt,w[i].x[1],w[i].x[2],w[i].x[3],f[i]);
ans=max(ans,f[i]);
}
cout<<ans;
return 0;
}
:::
:::info[双倍经验]{open} P5621 [DBOI2019] 德丽莎世界第一可爱
稍微有点小改,而且我被卡常了 O_o。 :::
eg.9 P6349 [PA 2011] Kangaroos
难一点的题。
:::info[省流] 前文提到可以在 KDT 上打懒标记维护区间操作,这道题就是 plus 版懒标记。大概就是前文维护的类似线段树的东西变成了吉司机线段树。 :::
套路化的,把区间
那么如何判定
黄色部分是无交的(左端点大于
那么现在需要维护区间加、区间赋值、历史最大值。直接可以 P4314 CPU 监控,维护当前最大值、历史最大值、当前加法标记、历史最大加法标记、当前赋值标记、历史最大赋值标记这些东西即可。
细节什么的看代码注释。
:::success[代码]
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5,inf=1e18;
int n,m,now,rt,qx[N],qy[N],a[N];
struct kdt{
int ls,rs;
int mx[2],mn[2],x[2];
int dat,his,ad,mx_ad,tg,mx_tg;//含义如文章所示
#define ls(x) tr[x].ls
#define rs(x) tr[x].rs
#define dat(x) tr[x].dat
#define his(x) tr[x].his
}tr[N];
bool cmp(int x,int y){return tr[x].x[now]<tr[y].x[now];}
void chkmin(int &x,int y){if(x>y)x=y;return;}
void chkmax(int &x,int y){if(x<y)x=y;return;}
void ps(int cur){
for(int i=0;i<2;i++){
tr[cur].mx[i]=tr[cur].mn[i]=tr[cur].x[i];
if(ls(cur))
chkmax(tr[cur].mx[i],tr[ls(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[ls(cur)].mn[i]);
if(rs(cur))
chkmax(tr[cur].mx[i],tr[rs(cur)].mx[i]),
chkmin(tr[cur].mn[i],tr[rs(cur)].mn[i]);
}
return;
}
void bd(int &cur,int l,int r,int k){
if(l>r)return cur=0,void();
int mid=l+r>>1;
now=k,nth_element(a+l,a+mid,a+r+1,cmp),cur=a[mid];
tr[cur].mx_tg=tr[cur].tg=-inf;
bd(ls(cur),l,mid-1,k^1),bd(rs(cur),mid+1,r,k^1),ps(cur);
return;
}
//主要内容是讲一下线段树的实现,kdt实现是基本模板
//对于加法操作,简单来说就是:
//对于没被赋过值的,加在加法标记上
//否则算作赋值的一部分,加在赋值标记上
void add(int cur,int val,int mx_val){//加
chkmax(his(cur),dat(cur)+mx_val),dat(cur)+=val;//更新当前最大值和历史最大值
if(tr[cur].tg==-inf)chkmax(tr[cur].mx_ad,tr[cur].ad+mx_val),tr[cur].ad+=val;
//如果当前没有被打上赋值标记,就将加法打在加法标记上
else chkmax(tr[cur].mx_tg,tr[cur].tg+mx_val),tr[cur].tg+=val;
//否则打在赋值标记上
return;
}
void tag(int cur,int val,int mx_val){//赋值
tr[cur].ad=0,dat(cur)=tr[cur].tg=val;//清空加法标记并赋值赋值标记
chkmax(his(cur),mx_val),chkmax(tr[cur].mx_tg,mx_val);//更新历史最大值和历史最大赋值标记
return;
}
void pd(int cur){//下传标记
//先传加再传赋值,这样加法的打在赋值标记上,再一同赋值
if(tr[cur].ad||tr[cur].mx_ad){
if(ls(cur))add(ls(cur),tr[cur].ad,tr[cur].mx_ad);
if(rs(cur))add(rs(cur),tr[cur].ad,tr[cur].mx_ad);
tr[cur].ad=tr[cur].mx_ad=0;
}
if(tr[cur].tg!=-inf){
if(ls(cur))tag(ls(cur),tr[cur].tg,tr[cur].mx_tg);
if(rs(cur))tag(rs(cur),tr[cur].tg,tr[cur].mx_tg);
tr[cur].tg=tr[cur].mx_tg=-inf;
}
return;
}
void upd(int cur,int l,int r){
if(!cur)return;
if(tr[cur].mx[1]<l||tr[cur].mn[0]>r)return tag(cur,0,0),void();//文章所说的不交情况,赋值为0
if(l<=tr[cur].mn[1]&&tr[cur].mx[0]<=r)return add(cur,1,1),void();//被相交包含,全部+1
//否则是部分相交部分不相交
pd(cur);
if(tr[cur].x[1]<l||tr[cur].x[0]>r)chkmax(his(cur),dat(cur)),dat(cur)=0;
else dat(cur)++,chkmax(his(cur),dat(cur));
//先判定根是否相交并打上对应标记,然后递归处理
upd(ls(cur),l,r),upd(rs(cur),l,r);
return;
}
void calc(int cur){//最后计算答案
if(!cur)return;
//别忘了下放标记
pd(cur),calc(ls(cur)),calc(rs(cur));
return;
}
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>qx[i]>>qy[i];
for(int i=1;i<=m;i++)cin>>tr[i].x[0]>>tr[i].x[1],a[i]=i;
bd(rt,1,m,0);
for(int i=1;i<=n;i++)upd(rt,qx[i],qy[i]);
calc(rt);
for(int i=1;i<=m;i++)cout<<tr[i].his<<"\n";
return 0;
}
:::
Part.4 后记
完结撒花!
本人才疏学浅,若有错误请不吝赐教qwq
如果有不解可以私信与我交流。