题解 P4475 【巧克力王国】

· · 题解

已到很不错的KD_tree入门题目

首先得知道什么是KD_tree

首先KD_tree是一颗二叉搜索树 树中储存了k维的数据信息 构树相当于是对k维空间进行划分的过程 所以每个节点就有了对应的k维空间的一个范围

比如当k=1时 这时候的KD_tree就是我们所熟悉的线段树 每个节点就对应了一维的一个区间

这里有不同的是 KD_tree的每个节点都储存了信息 类似伸展树(splay) 而线段树仅有叶子节点储存信息

这里我用结构体来呈现一个 KD_tree

一个节点含有的信息有

1. d[k] 每个维度的值
2. mx[2],mn[2] 该树及以下节点每个维度的max和min
3.lc,rc 左儿子和右儿子
根据题目需要 这里加入变量val(权值) 维护一个sum(权值和)

接下来是KD_tree的构造方法

我们把第i层的节点按照第 i%(维度数量)维度的优先级 取中位数(就是找到一个划分节点mid) 然后根据mid划分左右儿子 如此循环下去直到叶子节点

这是最常见的划分方法 但是容易被一些数据给卡住,见这篇文章 (以上解说或多或少的都借鉴了这篇文章)

然后是针对此题的查询

如果该节点的mx mn全部满足a x+b y < c

那么该节点一下的节点都满足 直接返回sum

否则就只能拆开该节点和左右儿子 递归下去

代码部分

/*
简单的入门KD_tree

首先需要一个专门用来排序的数组 dat
    其内容一般包括:
    1.每个维度的值
    2.该树及一下部分每个维度的max和min
    3.左儿子和右儿子
    4.权值之类的,我们需要维护的 (这里维护了一个sum
KD_tree本质是一个二叉搜索树 
我们把第i层的节点按照第 i%(维度数量)唯独的优先级 取中位数(就是找到一个划分节点mid
然后根据mid划分左右儿子 如此循环下去直到叶子节点

该题的思路是 如果该节点的mx mn全部满足a*x+b*y<c
那么该节点一下的节点都满足 直接返回sum
否则就只能拆开该节点和左右儿子 递归下去
*/
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const int N=5e5+50;
int n,now,m,rt;
ll a,b,c;
struct data
{
    int d[2],mx[2],mn[2],lc,rc,val;
    ll sum;
    friend bool operator < (data a,data b)
        {return a.d[now]<b.d[now];}
}dat[N],t[N];
void getmax(int&a,int b){if(a<b)a=b;}
void getmin(int&a,int b){if(a>b)a=b;}
void pushup(int x)
{
    int lc=t[x].lc,rc=t[x].rc;
    for(int i=0;i<2;i++)
    {
        t[x].mn[i]=t[x].mx[i]=t[x].d[i];
        if(lc)  getmin(t[x].mn[i],t[lc].mn[i]),
            getmax(t[x].mx[i],t[lc].mx[i]);
        if(rc)  getmin(t[x].mn[i],t[rc].mn[i]),
            getmax(t[x].mx[i],t[rc].mx[i]);
    }
    t[x].sum=t[lc].sum+t[rc].sum+t[x].val;
}

int build(int l,int r,int pl)
{
    now=pl; int mid=(l+r)>>1;
    nth_element(dat+l,dat+mid,dat+r+1);
    t[mid]=dat[mid];
    if(l<mid) t[mid].lc=build(l,mid-1,!pl);
    if(r>mid) t[mid].rc=build(mid+1,r,!pl);
    pushup(mid); return mid;
}
bool check(ll x,ll y) {return x*a+y*b<c;}
ll query(int x)
{
    int tot=0;
    tot+=check(t[x].mx[0],t[x].mx[1]);
    tot+=check(t[x].mn[0],t[x].mx[1]);
    tot+=check(t[x].mx[0],t[x].mn[1]);
    tot+=check(t[x].mn[0],t[x].mn[1]);
    if(tot==4) return t[x].sum; // 都满足
    if(tot==0) return 0; // 都不满足
    ll res=0;
    if(check(t[x].d[0],t[x].d[1])) res+=t[x].val;
    if(t[x].lc) res+=query(t[x].lc);
    if(t[x].rc) res+=query(t[x].rc);
    return res;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) 
        scanf("%d%d%d",&dat[i].d[0],&dat[i].d[1],&dat[i].val);
    rt=build(1,n,0); while(m--)
    {
        scanf("%lld%lld%lld",&a,&b,&c);
        printf("%lld\n",query(rt));
    }
    return 0;
}

再推荐一篇好的文章

至于KD_tree的邻值查询 (不是与本题无关吗,我也不会)