题解:P3769 [CH弱省胡策R2] TATT

· · 题解

这道题让我们求四维偏序,显然有 dp 式:

f_i=\max\limits_{a_j\le a_i,b_j\le b_i,c_j\le c_i,d_j\le d_i}f_j+1

按照 a 排序后,我们可以使用分治等方法优化它,我们还可以采用三维数据结构进行优化,只需要单点修改,立体查询即可,但是我们发现一个问题,树套树套树有着高达好几个 log 的复杂度,即使有 512 MB 也无法通过,时间上也比较紧张。

于是我们想办法优化这个做法,我们发现动态开点线段树每次修改会新建出 \log n 个点,我们不需要这么多点,所以我们在新建节点时记录一下这个点开到了哪里和记录了哪个数,然后直接返回,在查询时把标记的点算上即可。

一种可能的代码实现(最内层树):

struct Node{
    int ls,rs,idx,maxn,cnt;
};

int cnt;
int rt[3200000];
Node tr[3200000];

void update(int &i,int l,int r,int x,int k){
    if (i==0){
        i=++cnt;
        tr[i].idx=x;
        tr[i].cnt=tr[i].maxn=k;
        return;
    }
    tr[i].maxn=max(tr[i].maxn,k);
    if (tr[i].idx==x){
        tr[i].cnt=max(tr[i].cnt,k);
        return;
    }
    if (l==r) return;
    int mid=(l+r)>>1;
    if (mid>=x) update(tr[i].ls,l,mid,x,k);
    else update(tr[i].rs,mid+1,r,x,k);
}

int query(int i,int l,int r,int ql,int qr){
    if (i==0) return 0;
    if (l>=ql and r<=qr) return tr[i].maxn;
    int mid=(l+r)>>1,ans=0;
    if (ql<=tr[i].idx and tr[i].idx<=qr) ans=tr[i].cnt;
    if (mid>=ql) ans=max(ans,query(tr[i].ls,l,mid,ql,qr));
    if (mid+1<=qr) ans=max(ans,query(tr[i].rs,mid+1,r,ql,qr));
    return ans;
}

其中 idx 记录这个区间内新建点的下标,cnt 记录这个点的贡献,maxn 是正常的区间最大值。

但是,有了这个优化后还是不能通过,我们发现不仅最内层树可以这样做,中间的动态开点线段树也可以用同样的操作,用两个线段树分别记录 cntmaxn

这样,我们就可以轻松在只用 64 MB,最大点 515ms 的情况下通过这道题了。

完整代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

struct Seg{
    struct Node{
        int ls,rs,idx,maxn,cnt;
    };

    int cnt;
    int rt[3200000];
    Node tr[3200000];

    void update(int &i,int l,int r,int x,int k){
        if (i==0){
            i=++cnt;
            tr[i].idx=x;
            tr[i].cnt=tr[i].maxn=k;
            return;
        }
        tr[i].maxn=max(tr[i].maxn,k);
        if (tr[i].idx==x){
            tr[i].cnt=max(tr[i].cnt,k);
            return;
        }
        if (l==r) return;
        int mid=(l+r)>>1;
        if (mid>=x) update(tr[i].ls,l,mid,x,k);
        else update(tr[i].rs,mid+1,r,x,k);
    }

    int query(int i,int l,int r,int ql,int qr){
        if (i==0) return 0;
        if (l>=ql and r<=qr) return tr[i].maxn;
        int mid=(l+r)>>1,ans=0;
        if (ql<=tr[i].idx and tr[i].idx<=qr) ans=tr[i].cnt;
        if (mid>=ql) ans=max(ans,query(tr[i].ls,l,mid,ql,qr));
        if (mid+1<=qr) ans=max(ans,query(tr[i].rs,mid+1,r,ql,qr));
        return ans;
    }
}segin1,segin2;

struct Node{
    int ls,rs,idx;
};

struct Point{
    int a,b,c,d;
};

int n,rt[150001],b[150001],cnt,tot;
Node tr[3200000];
Point a[50001];
int f[50001];

void update(int &i,int l,int r,int x,int y,int k){
    if (i==0){
        i=++cnt;
        tr[i].idx=x;
        segin1.update(segin1.rt[i],1,tot,y,k);
        segin2.update(segin2.rt[i],1,tot,y,k);
        return;
    }
    segin1.update(segin1.rt[i],1,tot,y,k);
    if (tr[i].idx==x){
        segin2.update(segin2.rt[i],1,tot,y,k);
        return;
    }
    if (l==r) return;
    int mid=(l+r)>>1;
    if (mid>=x) update(tr[i].ls,l,mid,x,y,k);
    else update(tr[i].rs,mid+1,r,x,y,k);
}

int query(int i,int l,int r,int ql,int qr,int qlin,int qrin){
    if (i==0) return 0;
    if (l>=ql and r<=qr) return segin1.query(segin1.rt[i],1,tot,qlin,qrin);
    int mid=(l+r)>>1,ans=0;
    if (ql<=tr[i].idx and tr[i].idx<=qr) ans=segin2.query(segin2.rt[i],1,tot,qlin,qrin);
    if (mid>=ql) ans=max(ans,query(tr[i].ls,l,mid,ql,qr,qlin,qrin));
    if (mid+1<=qr) ans=max(ans,query(tr[i].rs,mid+1,r,ql,qr,qlin,qrin));
    return ans;
}

inline int lowbit(int x){
    return x&(-x);
}

void add(int x,int y,int z,int c){
    for (;x<=tot;x+=lowbit(x)) update(rt[x],1,tot,y,z,c);
}

int query(int x,int y,int z){
    int ans=0;
    for (;x;x-=lowbit(x)) ans=max(ans,query(rt[x],1,tot,1,y,1,z));
    return ans;
}

bool cmp(Point a,Point b){
    if (a.a==b.a){
        if (a.b==b.b){
            if (a.c==b.c){
                return a.d<b.d;
            }
            return a.c<b.c;
        }
        return a.b<b.b;
    }
    return a.a<b.a;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin>>n;
    for (int i=1;i<=n;i++){
        cin>>a[i].a>>a[i].b>>a[i].c>>a[i].d;
        b[++tot]=a[i].b;
        b[++tot]=a[i].c;
        b[++tot]=a[i].d;
    }
    sort(a+1,a+n+1,cmp);
    sort(b+1,b+tot+1);
    tot=unique(b+1,b+tot+1)-b-1;
    for (int i=1;i<=n;i++){
        a[i].b=lower_bound(b+1,b+tot+1,a[i].b)-b;
        a[i].c=lower_bound(b+1,b+tot+1,a[i].c)-b;
        a[i].d=lower_bound(b+1,b+tot+1,a[i].d)-b;
    }
    for (int i=1;i<=n;i++){
        f[i]=query(a[i].b,a[i].c,a[i].d)+1;
        add(a[i].b,a[i].c,a[i].d,f[i]);
    }
    cout<<*max_element(f+1,f+n+1);
    return 0;
}