题解 P4475 【巧克力王国】
已到很不错的KD_tree入门题目
首先得知道什么是KD_tree
首先KD_tree是一颗二叉搜索树 树中储存了k维的数据信息 构树相当于是对k维空间进行划分的过程 所以每个节点就有了对应的k维空间的一个范围
比如当k=1时 这时候的KD_tree就是我们所熟悉的线段树 每个节点就对应了一维的一个区间
这里有不同的是 KD_tree的每个节点都储存了信息 类似伸展树(splay) 而线段树仅有叶子节点储存信息
这里我用结构体来呈现一个 KD_tree
一个节点含有的信息有
1. d[k] 每个维度的值
2. mx[2],mn[2] 该树及以下节点每个维度的max和min
3.lc,rc 左儿子和右儿子
根据题目需要 这里加入变量val(权值) 维护一个sum(权值和)
接下来是KD_tree的构造方法
我们把第i层的节点按照第 i%(维度数量)维度的优先级 取中位数(就是找到一个划分节点mid) 然后根据mid划分左右儿子 如此循环下去直到叶子节点
这是最常见的划分方法 但是容易被一些数据给卡住,见这篇文章 (以上解说或多或少的都借鉴了这篇文章)
然后是针对此题的查询
如果该节点的mx mn全部满足a x+b y < c
那么该节点一下的节点都满足 直接返回sum
否则就只能拆开该节点和左右儿子 递归下去
代码部分
/*
简单的入门KD_tree
首先需要一个专门用来排序的数组 dat
其内容一般包括:
1.每个维度的值
2.该树及一下部分每个维度的max和min
3.左儿子和右儿子
4.权值之类的,我们需要维护的 (这里维护了一个sum
KD_tree本质是一个二叉搜索树
我们把第i层的节点按照第 i%(维度数量)唯独的优先级 取中位数(就是找到一个划分节点mid
然后根据mid划分左右儿子 如此循环下去直到叶子节点
该题的思路是 如果该节点的mx mn全部满足a*x+b*y<c
那么该节点一下的节点都满足 直接返回sum
否则就只能拆开该节点和左右儿子 递归下去
*/
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const int N=5e5+50;
int n,now,m,rt;
ll a,b,c;
struct data
{
int d[2],mx[2],mn[2],lc,rc,val;
ll sum;
friend bool operator < (data a,data b)
{return a.d[now]<b.d[now];}
}dat[N],t[N];
void getmax(int&a,int b){if(a<b)a=b;}
void getmin(int&a,int b){if(a>b)a=b;}
void pushup(int x)
{
int lc=t[x].lc,rc=t[x].rc;
for(int i=0;i<2;i++)
{
t[x].mn[i]=t[x].mx[i]=t[x].d[i];
if(lc) getmin(t[x].mn[i],t[lc].mn[i]),
getmax(t[x].mx[i],t[lc].mx[i]);
if(rc) getmin(t[x].mn[i],t[rc].mn[i]),
getmax(t[x].mx[i],t[rc].mx[i]);
}
t[x].sum=t[lc].sum+t[rc].sum+t[x].val;
}
int build(int l,int r,int pl)
{
now=pl; int mid=(l+r)>>1;
nth_element(dat+l,dat+mid,dat+r+1);
t[mid]=dat[mid];
if(l<mid) t[mid].lc=build(l,mid-1,!pl);
if(r>mid) t[mid].rc=build(mid+1,r,!pl);
pushup(mid); return mid;
}
bool check(ll x,ll y) {return x*a+y*b<c;}
ll query(int x)
{
int tot=0;
tot+=check(t[x].mx[0],t[x].mx[1]);
tot+=check(t[x].mn[0],t[x].mx[1]);
tot+=check(t[x].mx[0],t[x].mn[1]);
tot+=check(t[x].mn[0],t[x].mn[1]);
if(tot==4) return t[x].sum; // 都满足
if(tot==0) return 0; // 都不满足
ll res=0;
if(check(t[x].d[0],t[x].d[1])) res+=t[x].val;
if(t[x].lc) res+=query(t[x].lc);
if(t[x].rc) res+=query(t[x].rc);
return res;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d%d%d",&dat[i].d[0],&dat[i].d[1],&dat[i].val);
rt=build(1,n,0); while(m--)
{
scanf("%lld%lld%lld",&a,&b,&c);
printf("%lld\n",query(rt));
}
return 0;
}
再推荐一篇好的文章
至于KD_tree的邻值查询 (不是与本题无关吗,我也不会)