题解:P10637 BZOJ4262 Sum

· · 题解

一个不用 beats 的 DS 做法。

其实也本质一样的。具体就是注意到 P3246 还有一种不用 beats 的单调栈 + 线段树做法。

考虑只有 $\max$,记 $w(l,r)=\max\limits_{i=l}^ra_i$。 记 $$S(l,r,x)=\sum\limits_{i=l}^r\sum\limits_{j=1}^xw(i,j)$$ 差分得到: $$ \begin{aligned} \sum\limits_{i=l_1}^{r_1}\sum\limits_{j=l_2}^{r_2}w(i,j) &=\sum\limits_{i=l_1}^{r_1}\sum\limits_{j=1}^{r_2}w(i,j)-\sum\limits_{i=l_1}^{r_1}\sum\limits_{j=1}^{l_2-1}w(i,j)\\ &=S(l_1,r_1,r_2)-S(l_1,r_1,l_2-1) \end{aligned} $$ 把询问挂在 $x$ 上,按 $x$ 从小到大扫描线。 则扫到 $x$ 的时候 $S(l,r,x)$ 就是 $w(i,x)$ 历史和的 $[l,r]$ 区间和。 考虑维护 $w(i,x)$,直接做就是 beats 了。进一步可以发现 $w(i,x)$ 就是后缀 $\max$ 数组,有单调性。 所以掏出一个单调栈,把前缀取 $\max$ 的操作变成一个后缀推平,掏出一个区间推平区间历史和线段树维护即可。 【为了让读者看得到下面的内容,原来的代码被删了】 --- upd $2024.6.26

傻了,不需要区间推平。注意到单调栈的存在也维护了颜色段,所以直接每段区间加即可。

其次,这里有一个非常简单且跑得非常快的做法,感谢 @AbsMatt 的提醒。

注意到数据随机,根据经典结论单调栈的大小期望是 O(\log n) 的,所以直接遍历每个栈元素计算答案复杂度是对的。

具体就是,考虑每个栈元素,对历史和贡献分为“加入前”,“加入后删除前”,“删除后”。

“加入前”没有贡献,“删除后”的贡献就是一个固定的区间加,而“加入后删除前”的贡献形如 a_i\Delta t,其中 t 是时间戳。

所以删除一个栈元素的时候,就把它的贡献区间加到一个数据结构上,计算区间答案的时候“删除后”的贡献直接在数据结构上查询区间和即可。

对于计算时“加入后删除前”的元素,可以直接遍历每一个,贡献就是 a_i\Delta t\times lenlen 是该栈元素对应的区间和询问区间的交的长度。

这个数据结构可以是线段树,也可以是树状数组。可以做到比较小常数的 O(n\log n)

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
const int N=1e5+5;
int n=1e5,m,a[N];

ll c1[N],c2[N];
inline void upd(int l,int r,ll k){      // 区间加区间和的树状数组
    ll k2=k*(l-1);
    for(;l<=n;l+=l&-l){c1[l]+=k;c2[l]+=k2;}
    k2=k*r++;
    for(;r<=n;r+=r&-r){c1[r]-=k;c2[r]-=k2;}
}
inline ll qry(int p){
    ll s1=0,s2=0;
    for(int i=p;i;i^=i&-i){s1+=c1[i];s2+=c2[i];}
    return s1*p-s2;
}
vector<tuple<int,int,int,bool>>q[N];ll ans[N];
#define eb emplace_back
int top,stk[N],tim[N];
inline void solve(){
#define mem(a) memset(a,0,sizeof a)
    mem(c1);mem(c2);top=0;
    for(int x=1;x<=n;x++){
        for(;top&&a[x]>=a[stk[top]];top--)
            upd(stk[top-1]+1,stk[top],(ll)(x-tim[top])*a[stk[top]]);
        stk[++top]=x;tim[top]=x;
        for(auto[l,r,i,o]:q[x]){
            ll res=qry(r)-qry(l-1);
            for(int t=1;t<=top;t++){
                int len=min(stk[t],r)-max(stk[t-1],l-1);
                if(len>0)res+=(x-tim[t]+1ll)*a[stk[t]]*len;
            }
            ans[i]+=o?-res:res;
        }
    }
}
int main(){
    const int mod=1e9;
    int fst=1023,sec=1025;
    for(int i=1;i<=n;i++){
        a[i]=fst^sec;
        fst=fst*1023ll%mod;
        sec=sec*1025ll%mod;
    }
    scanf("%d",&m);
    for(int i=1,l1,r1,l2,r2;i<=m;i++){
        scanf("%d%d%d%d",&l1,&r1,&l2,&r2);
        q[r2].eb(l1,r1,i,0);
        q[l2-1].eb(l1,r1,i,1);
    }
    solve();
    for(int i=1;i<=n;i++)a[i]=-a[i];
    solve();
    for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
    return 0;
}