Stage4 - Needle

· · 题解

 作为验题人,有理由来写一篇题解。

 首先肯定要先将这些刺进行排序,于是按照直觉以 x 为第一关键字从小到大, y 为第二关键字从大到小排序。然后再来分析题目:令 s_1(x_1,y_1), s_2(x_2,y_2), s_3(x_3,y_3)。显然,向下的刺肯定是没用的。我们从 s_2 进行分析,发现对于每个 s_2,实际上是要找到其左下角有多少个满足条件的 s_1 且每个 s_1 有多少个 s_3 可以与其匹配。可以知道 s_1 一定在 s2 左下角, s_3 一定在 s_1 右下角,那么可以分别求解,对于每个 s_2 直接求其左下方的矩形 (x_2-d,-\infty ),(x_2-1,y_2-1) 中有多少个 s_1 即可,而对于每个 s_1 的贡献即为其右下方的矩形 (x_2,y_1-1),(\infty ,-\infty )s_3 的数量。每次暴力查询的复杂度肯定是无法接受的,那么我们需要寻找一下优化。可以发现如果按照之前的方法排序,再进行遍历的话,每一行 s_1 的数量是递增的, 每一行 s_3 的数量是递减的,而且查询矩形 (x_2-d,-\infty ),(x_2-1,y_2-1) 的右边界只会一直向右, (x_2,y_1-1),(\infty,-\infty) 的左边界也只会一直向右,那么我们就可以将查询矩形 x 轴的部分压缩掉,每次只用查询 (-\infty ,y_2-1][y_1-1,\infty ) 这两段区间即可。

 于是可以想到线段树。使用两棵线段树,第一棵用来维护每一行 s_3 个数的变化,第二棵线段树用来维护每一行 s_1 数量的变化以及每一行 s_1 产生的贡献。每一行 s_1 产生的贡献即为 该行 s_1 的个数 \times 区间 (-\infty,y1_1]s_3 的个数。但是,随着遍历时 s_2 的横坐标增大,会有很多 s_3 不再满足 x_2 \leq x_3 的条件,不能算作答案。如何规避这一点呢,不难发现,我们每遍历到一个向上的刺,这个刺就再也不可能作为答案了,因为剩余所有的 s_2 都不会满足 x_2 \leq x_3。那么每当我们遇到 s_3,就需要减去其之前产生的贡献。考虑可以与其匹配的 s_1 , 都在其左上角方,也就是对其左上方每一行中每一个 s_1 都会产生一个贡献,由于我们第二棵线段树维护的是每一行 s_1 产生的贡献,那么每一行新的贡献就相当于是 该行原贡献 - 该行 s_1 的数量。那么每减去一个 s_3 的贡献,就相当于在线段树上做了一遍区间操作而已。

 最后一个限制就是 x_2-x_1\leq d,考虑如何维护。很简单,遍历的同时将 s_1 塞到一个队列里,每次碰到 s_2,就拿队头判断,如果队头的 s_1 不满足条件,就减去其贡献,把它弹出,再继续比较,直到队头满足条件即可。单个 s_1 的贡献也很好求,相当于在第一棵线段树中做区间求和,然后在到第二棵线段树上进行单点修改即可。

 如此看来,遍历所有点的复杂度为 O(n),每个点上会进行的操作皆是在线段树上进行,都是 O(\log{n}) 的复杂度,所以总复杂度为 O(n \log{n}),虽然常数巨大,但也是能跑过的,其余不懂的就直接看代码吧。

#include<bits/stdc++.h>
#define fir first
#define sed second
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
const int N=3e5+10;
int T,id,n,d;
struct node
{
    int xx,yy;
    int type;
    bool operator<(const node aa)const
    {
        if(xx==aa.xx) return yy>aa.yy;
        else return xx<aa.xx;
    }
}a[N];
struct Node
{
    int l,r;
    int si,add_si;
    LL val,add_val;
    int tag;
};
int b1[N],b2[N],c[N];
int tot;
map<int,int> mp1,mp2;
int cnt1,cnt2;
LL ans;
struct Tree
{
    Node tr[N*4];
    void push_up(int pos)
    {
        tr[pos].si=tr[pos<<1].si+tr[pos<<1|1].si;
        tr[pos].val=tr[pos<<1].val+tr[pos<<1|1].val;
    }
    void push_del(int pos,int val)
    {
        tr[pos].val=tr[pos].val-(LL)tr[pos].si*val;
        tr[pos].tag+=val;
    }
    void push_si(int pos,int val)
    {
        tr[pos].si+=val;
        tr[pos].add_si+=val;
    }
    void push_val(int pos,LL val)
    {
        tr[pos].val+=val;
        tr[pos].add_val+=val;
    }
    void push_down(int pos)
    {
        if(tr[pos].tag)
        {
            push_del(pos<<1,tr[pos].tag);
            push_del(pos<<1|1,tr[pos].tag);
            tr[pos].tag=0;
        }
        if(tr[pos].add_si)
        {
            push_si(pos<<1,tr[pos].add_si);
            push_si(pos<<1|1,tr[pos].add_si);
            tr[pos].add_si=0;
        }
        if(tr[pos].add_val)
        {
            push_val(pos<<1,tr[pos].add_val);
            push_val(pos<<1|1,tr[pos].add_val);
            tr[pos].add_val=0;
        }
    }
    void build(int pos,int l,int r)
    {
        if(l>r) return ;
        tr[pos].l=l,tr[pos].r=r;
        tr[pos].si=tr[pos].val=0;
        tr[pos].tag=0;
        if(l==r)
        {
            tr[pos].si=c[l];
            return ;
        }
        int mid=l+r>>1;
        build(pos<<1,l,mid);
        build(pos<<1|1,mid+1,r);
        push_up(pos);
    }
    int query_si(int pos,int l,int r)
    {
        if(tr[pos].l>=l&&tr[pos].r<=r)
            return tr[pos].si;
        push_down(pos);
        int mid=tr[pos].l+tr[pos].r>>1;
        int sum=0;
        if(mid>=l) sum+=query_si(pos<<1,l,r);
        if(mid<r) sum+=query_si(pos<<1|1,l,r);
        return sum; 
    }
    LL query_val(int pos,int l,int r)
    {
        if(tr[pos].l>=l&&tr[pos].r<=r){
            return tr[pos].val;
        }
        push_down(pos);
        int mid=tr[pos].l+tr[pos].r>>1;
        LL sum=0;
        if(mid>=l) sum+=query_val(pos<<1,l,r);
        if(mid<r) sum+=query_val(pos<<1|1,l,r);
        return sum;
    }
    void modify_si(int pos,int xx,int val)
    {
        if(tr[pos].l==xx&&tr[pos].r==xx){
            tr[pos].si+=val;
            return ;
        }
        push_down(pos);
        int mid=tr[pos].l+tr[pos].r>>1;
        if(mid>=xx) modify_si(pos<<1,xx,val);
        else modify_si(pos<<1|1,xx,val);
        push_up(pos);
    }
    void modify_val(int pos,int xx,LL val)
    {
        if(tr[pos].l==xx&&tr[pos].r==xx){
            tr[pos].val+=val;
            return ;
        }
        push_down(pos);
        int mid=tr[pos].l+tr[pos].r>>1;
        if(mid>=xx) modify_val(pos<<1,xx,val);
        else modify_val(pos<<1|1,xx,val);
        push_up(pos);
    }
    void modify(int pos,int l,int r)
    {
        if(tr[pos].l>=l&&tr[pos].r<=r){
            push_del(pos,1);
            return ;
        }
        push_down(pos);
        int mid=tr[pos].l+tr[pos].r>>1;
        if(mid>=l) modify(pos<<1,l,r);
        if(mid<r) modify(pos<<1|1,l,r);
        push_up(pos);
    }
}tr1,tr2;
deque<PII > q;
int main()
{
    scanf("%d %d",&T,&id);
    while(T--)
    {
        tot=0,cnt1=cnt2=0;
        ans=0;
        mp1.clear();
        mp2.clear();
        while(!q.empty()) q.pop_front();
        scanf("%d %d",&n,&d);
        for(int i=1;i<=n;i++)
        {
            int xx,yy;
            char op[3];
            scanf("%d %d",&xx,&yy);
            scanf("%s",op);
            if(op[0]=='R'){
                a[++tot]={xx,yy,1};
                b1[tot]=xx,b2[tot]=yy;
            }
            else if(op[0]=='L'){
                a[++tot]={xx,yy,2};
                b1[tot]=xx,b2[tot]=yy;
            }
            else if(op[0]=='U'){
                a[++tot]={xx,yy,3};
                b1[tot]=xx,b2[tot]=yy;
            }
        }
        sort(a+1,a+1+tot);
        sort(b1+1,b1+1+tot);
        sort(b2+1,b2+1+tot);
        for(int i=1;i<=tot;i++){
            if(i==1||b1[i]!=b1[i-1]) mp1[b1[i]]=++cnt1;
            if(i==1||b2[i]!=b2[i-1]) mp2[b2[i]]=++cnt2;
        }
        for(int i=1;i<=cnt2;i++) c[i]=0;
        tr1.build(1,1,cnt2);
        for(int i=1;i<=tot;i++){
            if(a[i].type==3) c[mp2[a[i].yy]]++;
        }
        tr2.build(1,1,cnt2);
        for(int i=1;i<=tot;i++){
            if(a[i].type==1){
                int u=mp2[a[i].yy];
                if(u==1) continue ;
                int sum=tr2.query_si(1,1,u-1);
                tr1.modify_val(1,u,(LL)sum);
                tr1.modify_si(1,u,1);
                q.push_back(make_pair(a[i].xx,a[i].yy));
            }else if(a[i].type==2){
                while(!q.empty())
                {
                    if(a[i].xx-q.front().fir>d){
                        int u=mp2[q.front().sed];
                        if(u>1){
                            int sum=tr2.query_si(1,1,u-1);
                            tr1.modify_val(1,u,(LL)sum*-1LL);
                            tr1.modify_si(1,u,-1);
                        }
                        q.pop_front();
                    }else break ;
                }
                int u=mp2[a[i].yy];
                if(u==1) continue ;
                LL sum=tr1.query_val(1,1,u-1);
                ans+=sum;
            }else{
                int u=mp2[a[i].yy];
                tr2.modify_si(1,u,-1);
                if(u==cnt2) continue ;
                tr1.modify(1,u+1,cnt2);
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}