P3769题解

· · 题解

思路:

这是一道的四维 LIS 问题,类比二维 LIS,我们可以先将第一位排序。接着使用 DP 求解。

为了快速找出另外三维都比它小的数,我们可以用KD树进行维护。(KD 树擅长维护多维信息

实现

排序

需要注意的是,这里不能仅仅考虑第一维,而是要考虑所有,只是优先级不同。

例:对于输入:

2
1 0 0 1
1 0 0 0

如果只考虑第1维,那么其顺序不会发生改变。由于第前一个数据不能在第后一个数据的基础上更新,最后输出将为 1,但实际应该是 2。 代码:

bool cmp(Point a,Point b){
    if(a.dim[0]!=b.dim[0])return a.dim[0]<b.dim[0];
    if(a.dim[1]!=b.dim[1])return a.dim[1]<b.dim[1];
    if(a.dim[2]!=b.dim[2])return a.dim[2]<b.dim[2];
    return a.dim[3]<b.dim[3];
}
sort(ele+1,ele+1+n,cmp);//主函数中

插入元素

由于 KD 树无法旋转或分裂,这里使用替罪羊树进行维护平衡。

int bin[szl],top;//注意写垃圾回收
int New(){
    return top?bin[top--]:++tot;
}
void Slap(int u){
    if(!u) return;
    Slap(ls(u));
    order[++cnt] = kdt[u].poi;
    bin[++top]=u;
    Slap(rs(u));
    return;
}
int Build(int l,int r,int d){
    if(l>r)return 0;
    d=d%(K-1)+1;
    int p=New();
    int mid=l+r>>1;
    cmpd=d;
    nth_element(order+l,order+mid,order+r+1);
    poi(p)=order[mid];
    ls(p)=Build(l,mid-1,d+1);
    rs(p)=Build(mid+1,r,d+1);
    Up(p);
    return p;
}
bool imblce(int p){
    return siz(ls(p))>alpha*siz(p)||siz(rs(p))>alpha*siz(p);
}
void Insert(int &p,Point now,int d){
    d=d%(K-1)+1;
    if(!p){
        p=New();
        ls(p)=rs(p)=0,poi(p)=now;
        Up(p);
        return;
    }
    if(now.dim[d]<=poi(p).dim[d])Insert(ls(p),now,d+1);
    else                         Insert(rs(p),now,d+1);
    Up(p);
    if(imblce(p)){
        cnt=0;
        Slap(p);
        p=Build(1,siz(p),d);
    }
}

查询

分为全部符合,全部不符合,和部分符合。

int Query(int p,int x,int y,int z){//找最大值 
    if(!p)return 0;
    if(mx(p)[1]<=x&&mx(p)[2]<=y&&mx(p)[3]<=z)return mxval(p);//完全符合
    if(mn(p)[1]>x||mn(p)[2]>y||mn(p)[3]>z)return 0;//完全不符合
    //部分符合,向下寻找 
    bool d=poi(p).dim[1]<=x&&poi(p).dim[2]<=y&&poi(p).dim[3]<=z;
    return max(max(Query(ls(p),x,y,z),Query(rs(p),x,y,z)),d*poi(p).val);
}

最终代码:

//KD-Tree优化DP 
#include<bits/stdc++.h>
using namespace std;
const int szl=5e5+5,K=4;//树中用1~K-1维 
const double alpha=0.75;
int n,cmpd,tot,cnt,root;
struct Point{
    int dim[K],val;
    void Read(){
        for(int i=0;i<K;i++)cin>>dim[i];
        val=0;
        return;
    }
    bool operator<(const Point &tmp)const{
        return dim[cmpd]<tmp.dim[cmpd];
    }
}ele[szl],order[szl];
bool cmp(Point a,Point b){
    if(a.dim[0]!=b.dim[0])return a.dim[0]<b.dim[0];
    if(a.dim[1]!=b.dim[1])return a.dim[1]<b.dim[1];
    if(a.dim[2]!=b.dim[2])return a.dim[2]<b.dim[2];
    return a.dim[3]<b.dim[3];
}
int bin[szl],top;
int New(){
    return top?bin[top--]:++tot;
}
struct KdTr{
    int mn[K],mx[K],ls,rs,mxval,siz;
    Point poi; 
    //用于减少码量
    #define ls(x) kdt[x].ls
    #define rs(x) kdt[x].rs
    #define mn(x) kdt[x].mn
    #define mx(x) kdt[x].mx
    #define mxval(x) kdt[x].mxval
    #define siz(x) kdt[x].siz
    #define poi(x) kdt[x].poi
}kdt[szl];
void Up(int p){
    for(int i=1;i<K;i++){
        mn(p)[i]=mx(p)[i]=poi(p).dim[i];
        if(ls(p)){
            mn(p)[i]=min(mn(p)[i],mn(ls(p))[i]);
            mx(p)[i]=max(mx(p)[i],mx(ls(p))[i]);
        }
        if(rs(p)){
            mn(p)[i]=min(mn(p)[i],mn(rs(p))[i]);
            mx(p)[i]=max(mx(p)[i],mx(rs(p))[i]);
        }
    }
    mxval(p)=max(max(mxval(ls(p)),mxval(rs(p))),poi(p).val);
    siz(p)=siz(ls(p))+siz(rs(p))+1;
    return;
}
void Slap(int u){
    if(!u) return;
    Slap(ls(u));
    order[++cnt] = kdt[u].poi;
    bin[++top]=u;
    Slap(rs(u));
    return;
}
int Build(int l,int r,int d){
    if(l>r)return 0;
    d=d%(K-1)+1;
    int p=New();
    int mid=l+r>>1;
    cmpd=d;
    nth_element(order+l,order+mid,order+r+1);
    poi(p)=order[mid];
    ls(p)=Build(l,mid-1,d+1);
    rs(p)=Build(mid+1,r,d+1);
    Up(p);
    return p;
}
bool imblce(int p){
    return siz(ls(p))>alpha*siz(p)||siz(rs(p))>alpha*siz(p);
}
void Insert(int &p,Point now,int d){
    d=d%(K-1)+1;
    if(!p){
        p=New();
        ls(p)=rs(p)=0,poi(p)=now;
        Up(p);
        return;
    }
    if(now.dim[d]<=poi(p).dim[d])Insert(ls(p),now,d+1);
    else                         Insert(rs(p),now,d+1);
    Up(p);
    if(imblce(p)){
        cnt=0;
        Slap(p);
        p=Build(1,siz(p),d);
    }
}
int ans=0;
int Query(int p,int x,int y,int z){//找最大值 
    if(!p)return 0;
    if(mx(p)[1]<=x&&mx(p)[2]<=y&&mx(p)[3]<=z)return mxval(p);//完全符合
    if(mn(p)[1]>x||mn(p)[2]>y||mn(p)[3]>z)return 0;//完全不符合
    //向下寻找 
    bool d=poi(p).dim[1]<=x&&poi(p).dim[2]<=y&&poi(p).dim[3]<=z;
    return max(max(Query(ls(p),x,y,z),Query(rs(p),x,y,z)),d*poi(p).val);
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++)ele[i].Read();
    sort(ele+1,ele+1+n,cmp);
    for(int i=1;i<=n;i++){
        ele[i].val=Query(root,ele[i].dim[1],ele[i].dim[2],ele[i].dim[3])+1;
        Insert(root,ele[i],1);
    }
    cout<<mxval(root);
    return 0;
}