题解 P6242 【【模板】线段树3】

· · 题解

Update 3/29:修正了一些笔误和可能令人误解地方。

关于线段树维护区间最值操作和区间历史最值的完整方法,请参阅吉如一 2016 年的国家集训队论文 和 我的洛谷日报。这篇题解大部分内容是复制于后者。谢谢大家给我的日报捧场 ^_^

概述

首先假设我们有一个序列 A。那么,

在这道题中,操作 2 是区间最小值操作,操作 5 是求历史最大值的区间最大值。

题解

首先,让我们考虑只有操作 2,3,4 的情况。我们发现,关于区间最大值的询问是可以使用带懒标记的线段树维护的,而区间和则不是很容易求出,因为这不能在下传标记时快速更新答案。

考虑到区间取 \min 的操作只会对最大值不超过 k 的节点产生影响,我们可以在这方面产生思路。为了使复杂度变对,线段树的一个节点要维护三个信息:区间最大值 mx,区间严格次大值 se 和最大值的个数 cnt。那么,一次区间最值操作作用在这个节点上时,可以被分为以下三种情况:

举个例子,现在要以 k=2 对下面这棵线段树维护的区间取 \min。每个节点左侧表示区间最大值,右侧表示严格次大值。(很抱歉用了吉老师论文里的例子,因为我实在找不出比它更适合做例子的例子了)

按照上面描述的操作,我们应该沿着红色的边 DFS,最后在红色的节点上更新区间和并打上标记。

这样做的复杂度可以证明是 O(m\log n) 的。原论文的证明使用了势能分析,这里给出一种更好懂的 O((n+m)\log n) 下界的证明:

显然只有暴力 DFS 的过程会影响复杂度,我们考虑一次区间最值操作中哪些节点会被搜索到。发现到达的节点区间中不同的数的个数(下称“值域”)一定会减少(因为至少将最大值与次大值合并了)。线段树每层节点表示的区间的值域一共是 O(n) 的,一共 \log 层加起来就是 O(n\log n)。再加上每次操作的复杂度,可以得到复杂度的一个下界为 O((n+m)\log n)

然后,我们加入区间加减操作(操作 1)。发现上面方法的本质是把线段树上每个节点维护的元素分成了最大值和非最大值这两类,并对它们分别维护。区间最大值标记只会打在 se<v<mx 的节点上,可以看成只针对最大值的加减操作;而区间加操作对于这两类值都生效。分别维护区间最大值的加减标记和非最大值的加减标记,下传时判断当前节点是否为区间最大值即可。

由于区间加减操作,某些节点的值域会增大。一次区间加减最多会影响 O(\log^2n) 个节点的值域,所以用之前的证明方法可证复杂度的一个上界为 O(n\log n+m\log^2n)(论文里给的是 O(m\log^2n))。而实际实现时会发现这个上界其实往往是跑不满的,速度几乎与大常数一个 \log 接近。

目前为止,我们已经解决了测试点 3,4。在此总结一下,上面的这种处理区间最值的方法的核心是数域的划分。也就是说,我们以一个 \log 的代价,将区间最值操作转化成了最值数域上的区间加减操作。

现在我们再考虑只有操作 1,3,4,5 的情况。 我们在线段树的每个节点上维护两个值:区间当前最大值 mx 和历史最大值 mx',以及两个标记:区间加标记 add 和历史最大加减标记 add'。这里唯一不容易理解的就是“历史最大加减标记”,它表示的是从上一次下传 add 到现在,区间加标记达到的最大值

考虑历史最值标记的合并。当节点 x 下传标记到它的儿子 ch 时,儿子历史最大值应与当前值取 \max,写成公式就是下面的样子:

add'_{ch}\gets\max(add'_{ch},add_{ch}+add'_x) mx'_{ch}\gets\max(mx'_{ch},mx_{ch}+add'_x)

这样做的正确性是显然的,区间和可以用常规方法维护,复杂度为 O(m\log n)。于是我们又解决了测试点 5,6

现在考虑将这两部分合并起来,在这里我们应用之前的思想,把一个区间的元素划分成最大值和非最大值两部分。那么区间最值操作可以转化为对最大值的加减操作,而题目中的区间加减则同时作用于这两类值。于是我们以一个 \log 的代价把问题就变成了两个数域上的区间加减操作和历史最值询问。具体地说,我们需要维护四种标记:

前两个标记是只修改最大值的,所以下传时要判断当前节点是否包含区间最大值。总复杂度为 O(m\log^2 n),需要注意一下常数。

如果有地方不太理解的话,可以看下面的代码:

#include <cstdio>
#include <cctype>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN=550000;
const int INF=2E9+233;
int read()
{
    int x=0, w=0; char ch=0;
    while (!isdigit(ch)) w|=ch=='-', ch=getchar();
    while (isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48), ch=getchar();
    return w?-x:x;
}
struct SegmentTree
{
    struct Node
    {
        int l, r;
        int mx, mx_, se, cnt; ll sum;
        int add1, add1_, add2, add2_;
    } tr[4*MAXN];
    #define lc (o<<1)
    #define rc (o<<1|1)
    void pushup(int o)
    {
        tr[o].sum=tr[lc].sum+tr[rc].sum;
        tr[o].mx_=max(tr[lc].mx_, tr[rc].mx_);
        if (tr[lc].mx==tr[rc].mx)
        {
            tr[o].mx=tr[lc].mx;
            tr[o].se=max(tr[lc].se, tr[rc].se);
            tr[o].cnt=tr[lc].cnt+tr[rc].cnt;
        }
        else if (tr[lc].mx>tr[rc].mx)
        {
            tr[o].mx=tr[lc].mx;
            tr[o].se=max(tr[lc].se, tr[rc].mx);
            tr[o].cnt=tr[lc].cnt;
        }
        else
        {
            tr[o].mx=tr[rc].mx;
            tr[o].se=max(tr[lc].mx, tr[rc].se);
            tr[o].cnt=tr[rc].cnt;
        }
    }
    void update(int o, int k1, int k1_, int k2, int k2_)
    {
        tr[o].sum+=1ll*k1*tr[o].cnt+1ll*k2*(tr[o].r-tr[o].l+1-tr[o].cnt);
        tr[o].mx_=max(tr[o].mx_, tr[o].mx+k1_);
        tr[o].add1_=max(tr[o].add1_, tr[o].add1+k1_);
        tr[o].mx+=k1, tr[o].add1+=k1;
        tr[o].add2_=max(tr[o].add2_, tr[o].add2+k2_);
        if (tr[o].se!=-INF) tr[o].se+=k2;
        tr[o].add2+=k2;
    }
    void pushdown(int o)
    {
        int tmp=max(tr[lc].mx, tr[rc].mx);
        if (tr[lc].mx==tmp)
            update(lc, tr[o].add1, tr[o].add1_, tr[o].add2, tr[o].add2_);
        else update(lc, tr[o].add2, tr[o].add2_, tr[o].add2, tr[o].add2_);
        if (tr[rc].mx==tmp)
            update(rc, tr[o].add1, tr[o].add1_, tr[o].add2, tr[o].add2_);
        else update(rc, tr[o].add2, tr[o].add2_, tr[o].add2, tr[o].add2_);
        tr[o].add1=tr[o].add1_=tr[o].add2=tr[o].add2_=0;
    }
    void build(int o, int l, int r, int* a)
    {
        tr[o].l=l, tr[o].r=r;
        tr[o].add1=tr[o].add1_=tr[o].add2=tr[o].add2_=0;
        if (l==r)
        {
            tr[o].sum=tr[o].mx_=tr[o].mx=a[l];
            tr[o].se=-INF, tr[o].cnt=1;
            return;
        }
        int mid=l+r>>1;
        build(lc, l, mid, a);
        build(rc, mid+1, r, a);
        pushup(o);
    }
    void modify1(int o, int l, int r, int k)
    {
        if (tr[o].l>r||tr[o].r<l) return;
        if (l<=tr[o].l&&tr[o].r<=r)
            { update(o, k, k, k, k); return; }
        pushdown(o);
        modify1(lc, l, r, k), modify1(rc, l, r, k);
        pushup(o);
    }
    void modify2(int o, int l, int r, int k)
    {
        if (tr[o].l>r||tr[o].r<l||k>=tr[o].mx) return;
        if (l<=tr[o].l&&tr[o].r<=r&&k>tr[o].se)
            { update(o, k-tr[o].mx, k-tr[o].mx, 0, 0); return; }
        pushdown(o);
        modify2(lc, l, r, k), modify2(rc, l, r, k);
        pushup(o);
    }
    ll query3(int o, int l, int r)
    {
        if (tr[o].l>r||tr[o].r<l) return 0;
        if (l<=tr[o].l&&tr[o].r<=r) return tr[o].sum;
        pushdown(o);
        return query3(lc, l, r)+query3(rc, l, r);
    }
    int query4(int o, int l, int r)
    {
        if (tr[o].l>r||tr[o].r<l) return -INF;
        if (l<=tr[o].l&&tr[o].r<=r) return tr[o].mx;
        pushdown(o);
        return max(query4(lc, l, r), query4(rc, l, r));
    }
    int query5(int o, int l, int r)
    {
        if (tr[o].l>r||tr[o].r<l) return -INF;
        if (l<=tr[o].l&&tr[o].r<=r) return tr[o].mx_;
        pushdown(o);
        return max(query5(lc, l, r), query5(rc, l, r));
    }
    #undef lc
    #undef rc
} sgt;
int a[MAXN];
int main()
{
//  freopen("segbeats.in", "r", stdin);
//  freopen("segbeats.out", "w", stdout);
    int n=read(), m=read();
    for (int i=1; i<=n; i++) a[i]=read();
    sgt.build(1, 1, n, a);
    while (m--)
    {
        int op=read(), l, r, k;
        switch (op)
        {
            case 1:
                l=read(), r=read(), k=read();
                sgt.modify1(1, l, r, k);
                break;
            case 2:
                l=read(), r=read(), k=read();
                sgt.modify2(1, l, r, k);
                break;
            case 3:
                l=read(), r=read();
                printf("%lld\n", sgt.query3(1, l, r));
                break;
            case 4:
                l=read(), r=read();
                printf("%d\n", sgt.query4(1, l, r));
                break;
            case 5:
                l=read(), r=read();
                printf("%d\n", sgt.query5(1, l, r));
                break;
        }
    }
    return 0;
}