题解 P6242 【【模板】线段树3】
Update 3/29:修正了一些笔误和可能令人误解地方。
关于线段树维护区间最值操作和区间历史最值的完整方法,请参阅吉如一 2016 年的国家集训队论文 和 我的洛谷日报。这篇题解大部分内容是复制于后者。谢谢大家给我的日报捧场 ^_^
概述
首先假设我们有一个序列
-
区间最值操作是指,给定
l,r,k ,对于所有i\in[l,r] ,将A_i 变成\min(A_i,k) 或\max(A_i,k) 的一种操作。 -
对于序列
A 的某个位置i ,它的历史最值是指A_i 从初始化到当前达到的最大值 / 最小值,或者每一次修改操作后的值之和。我们把这三种分别称作历史最大值、历史最小值和历史版本和。而区间历史最值则是指求序列A 的历史最值的最大值 / 最小值或者代数和等的区间问题。
在这道题中,操作
题解
首先,让我们考虑只有操作
考虑到区间取
举个例子,现在要以
按照上面描述的操作,我们应该沿着红色的边 DFS,最后在红色的节点上更新区间和并打上标记。
这样做的复杂度可以证明是
显然只有暴力 DFS 的过程会影响复杂度,我们考虑一次区间最值操作中哪些节点会被搜索到。发现到达的节点区间中不同的数的个数(下称“值域”)一定会减少(因为至少将最大值与次大值合并了)。线段树每层节点表示的区间的值域一共是
O(n) 的,一共\log 层加起来就是O(n\log n) 。再加上每次操作的复杂度,可以得到复杂度的一个下界为O((n+m)\log n) 。
然后,我们加入区间加减操作(操作
由于区间加减操作,某些节点的值域会增大。一次区间加减最多会影响
目前为止,我们已经解决了测试点
现在我们再考虑只有操作
考虑历史最值标记的合并。当节点
这样做的正确性是显然的,区间和可以用常规方法维护,复杂度为
现在考虑将这两部分合并起来,在这里我们应用之前的思想,把一个区间的元素划分成最大值和非最大值两部分。那么区间最值操作可以转化为对最大值的加减操作,而题目中的区间加减则同时作用于这两类值。于是我们以一个
- 最大值加减标记
- 最大值历史最大的加减标记
- 非最大值加减标记
- 非最大值历史最大的加减标记
前两个标记是只修改最大值的,所以下传时要判断当前节点是否包含区间最大值。总复杂度为
如果有地方不太理解的话,可以看下面的代码:
#include <cstdio>
#include <cctype>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN=550000;
const int INF=2E9+233;
int read()
{
int x=0, w=0; char ch=0;
while (!isdigit(ch)) w|=ch=='-', ch=getchar();
while (isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48), ch=getchar();
return w?-x:x;
}
struct SegmentTree
{
struct Node
{
int l, r;
int mx, mx_, se, cnt; ll sum;
int add1, add1_, add2, add2_;
} tr[4*MAXN];
#define lc (o<<1)
#define rc (o<<1|1)
void pushup(int o)
{
tr[o].sum=tr[lc].sum+tr[rc].sum;
tr[o].mx_=max(tr[lc].mx_, tr[rc].mx_);
if (tr[lc].mx==tr[rc].mx)
{
tr[o].mx=tr[lc].mx;
tr[o].se=max(tr[lc].se, tr[rc].se);
tr[o].cnt=tr[lc].cnt+tr[rc].cnt;
}
else if (tr[lc].mx>tr[rc].mx)
{
tr[o].mx=tr[lc].mx;
tr[o].se=max(tr[lc].se, tr[rc].mx);
tr[o].cnt=tr[lc].cnt;
}
else
{
tr[o].mx=tr[rc].mx;
tr[o].se=max(tr[lc].mx, tr[rc].se);
tr[o].cnt=tr[rc].cnt;
}
}
void update(int o, int k1, int k1_, int k2, int k2_)
{
tr[o].sum+=1ll*k1*tr[o].cnt+1ll*k2*(tr[o].r-tr[o].l+1-tr[o].cnt);
tr[o].mx_=max(tr[o].mx_, tr[o].mx+k1_);
tr[o].add1_=max(tr[o].add1_, tr[o].add1+k1_);
tr[o].mx+=k1, tr[o].add1+=k1;
tr[o].add2_=max(tr[o].add2_, tr[o].add2+k2_);
if (tr[o].se!=-INF) tr[o].se+=k2;
tr[o].add2+=k2;
}
void pushdown(int o)
{
int tmp=max(tr[lc].mx, tr[rc].mx);
if (tr[lc].mx==tmp)
update(lc, tr[o].add1, tr[o].add1_, tr[o].add2, tr[o].add2_);
else update(lc, tr[o].add2, tr[o].add2_, tr[o].add2, tr[o].add2_);
if (tr[rc].mx==tmp)
update(rc, tr[o].add1, tr[o].add1_, tr[o].add2, tr[o].add2_);
else update(rc, tr[o].add2, tr[o].add2_, tr[o].add2, tr[o].add2_);
tr[o].add1=tr[o].add1_=tr[o].add2=tr[o].add2_=0;
}
void build(int o, int l, int r, int* a)
{
tr[o].l=l, tr[o].r=r;
tr[o].add1=tr[o].add1_=tr[o].add2=tr[o].add2_=0;
if (l==r)
{
tr[o].sum=tr[o].mx_=tr[o].mx=a[l];
tr[o].se=-INF, tr[o].cnt=1;
return;
}
int mid=l+r>>1;
build(lc, l, mid, a);
build(rc, mid+1, r, a);
pushup(o);
}
void modify1(int o, int l, int r, int k)
{
if (tr[o].l>r||tr[o].r<l) return;
if (l<=tr[o].l&&tr[o].r<=r)
{ update(o, k, k, k, k); return; }
pushdown(o);
modify1(lc, l, r, k), modify1(rc, l, r, k);
pushup(o);
}
void modify2(int o, int l, int r, int k)
{
if (tr[o].l>r||tr[o].r<l||k>=tr[o].mx) return;
if (l<=tr[o].l&&tr[o].r<=r&&k>tr[o].se)
{ update(o, k-tr[o].mx, k-tr[o].mx, 0, 0); return; }
pushdown(o);
modify2(lc, l, r, k), modify2(rc, l, r, k);
pushup(o);
}
ll query3(int o, int l, int r)
{
if (tr[o].l>r||tr[o].r<l) return 0;
if (l<=tr[o].l&&tr[o].r<=r) return tr[o].sum;
pushdown(o);
return query3(lc, l, r)+query3(rc, l, r);
}
int query4(int o, int l, int r)
{
if (tr[o].l>r||tr[o].r<l) return -INF;
if (l<=tr[o].l&&tr[o].r<=r) return tr[o].mx;
pushdown(o);
return max(query4(lc, l, r), query4(rc, l, r));
}
int query5(int o, int l, int r)
{
if (tr[o].l>r||tr[o].r<l) return -INF;
if (l<=tr[o].l&&tr[o].r<=r) return tr[o].mx_;
pushdown(o);
return max(query5(lc, l, r), query5(rc, l, r));
}
#undef lc
#undef rc
} sgt;
int a[MAXN];
int main()
{
// freopen("segbeats.in", "r", stdin);
// freopen("segbeats.out", "w", stdout);
int n=read(), m=read();
for (int i=1; i<=n; i++) a[i]=read();
sgt.build(1, 1, n, a);
while (m--)
{
int op=read(), l, r, k;
switch (op)
{
case 1:
l=read(), r=read(), k=read();
sgt.modify1(1, l, r, k);
break;
case 2:
l=read(), r=read(), k=read();
sgt.modify2(1, l, r, k);
break;
case 3:
l=read(), r=read();
printf("%lld\n", sgt.query3(1, l, r));
break;
case 4:
l=read(), r=read();
printf("%d\n", sgt.query4(1, l, r));
break;
case 5:
l=read(), r=read();
printf("%d\n", sgt.query5(1, l, r));
break;
}
}
return 0;
}