题解 P6881 【[JOI 2020 Final] 火事】

· · 题解

数据结构好题。

我们考虑画出每一时刻所有位置的值:

这其中,我们可以找到最标准的情形:里面的元素 \color{cyan}{7}。可以看到,随着时间增加,它呈现一个平行四边形。

那么这个平行四边形由上面构成呢?我们考虑用单调栈找出每个位置向左向右第一个比它大的位置 l_i,r_i。则我们会发现,平行四边形的四个顶点是 (i,0),(r_i-1,r_i-i-1),(r_i-1,r_i-l_i-2),(i,i-l_i-1)

我们考虑此时的询问,在坐标系上,就成为了一段水平线段。考虑差分一下,就变成了两条沿 x 轴负方向的的射线。

考虑把一个平行四边形拆成三个三角形:

如图,绿色平行四边形=大三角形-蓝三角形-紫三角形。

然后我们发现,每个三角形的形状都是一样的,并且都有一条腰在 x 轴上。故我们可以用其上方顶点来唯一确定一个三角形。则有大三角形顶点 (r_i-1,r_i-l_i-2),蓝三角形顶点 (i-1,i-l_i-2),紫三角形顶点 (r_i-1,r_i-i-2)

我们考虑一条沿 x 轴负方向的射线与三角形的交线段长度:

如图,如果该射线是绿线(原点在三角形内),可以将其平移至紫线位置,然后就可以拆括号计算了。如果该直线是棕线,则也可以简单计算。

于是我们可以用扫描线维护对于递增的射线原点的 x 坐标,相应的位于其左侧的三角形们和位于其右侧的三角形们,然后分别转移即可。

至于如何转移,可以各通过两棵树状数组,一棵维护系数和,一棵维护带权坐标和,然后分别转移即可。共需四棵树状数组。

时间复杂度 O(n\log n)

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,a[200100],l[200100],r[200100],stk[200100],tp,lim,mil;
ll res[200100];
namespace BIT1{//a BIT for triangular summary.
    ll t1[600100],t2[600100];//t1 for weighted sum, t2 for common sum.
    void ADD(int x,int y){
        x--;
        ll tmp=1ll*y*x;
        for(int i=x+n+1;i<=3*n;i+=i&-i)t1[i]+=tmp,t2[i]+=y;
    }
    ll SUM(int x){
        ll ret=0;
        for(int i=x+n+1;i;i-=i&-i)ret-=t1[i],ret+=t2[i]*x;
        return ret;
    }
}
namespace BIT2{
    ll t1[600100],t2[600100];
    void ADD(int x,int y){
        x++;
        ll tmp=1ll*y*x;
        for(int i=x+n+1;i;i-=i&-i)t1[i]+=tmp,t2[i]+=y;
    }
    ll SUM(int x){
        ll ret=0;
        for(int i=x+n+1;i<=3*n;i+=i&-i)ret+=t1[i],ret-=t2[i]*x;
        return ret;
    } 
}
void ADD1(int x,int y,int z){BIT1::ADD(x-y,z);}
void ADD2(int x,int y,int z){BIT2::ADD(y,z);}
struct Query{
    int x,y,id;
    Query(int u=0,int v=0,int w=0){x=u,y=v,id=w;}
    friend bool operator <(const Query &x,const Query &y){return x.x<y.x;}
}q[400100];
struct Tri{
    int x,y,a;
    Tri(int u=0,int v=0,int w=0){x=u,y=v,a=w;}
    friend bool operator <(const Tri &x,const Tri &y){return x.x<y.x;}
}p[800100];
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    for(int i=n;i>=1;i--){
        while(tp&&a[i]>a[stk[tp]])l[stk[tp--]]=i;
        stk[++tp]=i;
    }
    while(tp)l[stk[tp]]=stk[tp]-n-1,tp--;
    for(int i=1;i<=n;i++){
        while(tp&&a[stk[tp]]<=a[i])r[stk[tp--]]=i;
        stk[++tp]=i;
    }
    while(tp)r[stk[tp--]]=n+1;
    for(int i=1;i<=n;i++){
        p[++lim]=Tri(i-1,i-l[i]-2,-a[i]);
        p[++lim]=Tri(r[i]-1,r[i]-l[i]-2,a[i]);
        p[++lim]=Tri(r[i]-1,r[i]-i-2,-a[i]);
    }
    for(int i=1;i<=lim;i++)ADD1(p[i].x,p[i].y,p[i].a);
    sort(p+1,p+lim+1);
    for(int i=1,t,l,r;i<=m;i++){
        scanf("%d%d%d",&t,&l,&r);
        q[++mil]=Query(l-1,t,-i);
        q[++mil]=Query(r,t,i);
    }
    sort(q+1,q+mil+1);
    for(int i=1,j=1;i<=mil;i++){
        while(j<=lim&&p[j].x<q[i].x)ADD1(p[j].x,p[j].y,-p[j].a),ADD2(p[j].x,p[j].y,p[j].a),j++;
        int lambda=(q[i].id>0?1:-1);q[i].id=abs(q[i].id);
        res[q[i].id]+=(BIT1::SUM(q[i].x-q[i].y)+BIT2::SUM(q[i].y))*lambda;
    }
    for(int i=1;i<=m;i++)printf("%lld\n",res[i]);
//  for(int i=1;i<=n;i++)printf("%d %d\n",l[i],r[i]);
    return 0;
}