题解:P9877 [EC Final 2021] Vacation

· · 题解

套路地,将数列每 c 个划分为一个块,那么答案的区间要么在一个块内,要么在两个相邻的两个块内。

如果答案在一个块内,这是好处理的,我们处理一个线段树单独维护(也就是开 \lceil\frac{n}{c}\rceil 棵大小为 c 的线段树)每个块的前缀最大值、后缀最大值、区间和以及子段最大值,合并就是左端点后缀最大值与右端点前缀最小值相加取最大值就好了。

因为查询可能涉及多个块,所以我们需要再有一颗线段树维护连续块的答案。

然后观察到相邻的两个块的情况:

如果 j-i+1\le c,要求满足什么?

我们把右边的块拿下来:

如果原本满足 j-i+1\le c,若我们把每个块的元素重新编号,那么 j<i

那么我们可以维护区间的前缀最大值(是下一个段的),后缀最大值(是这一个段的),区间和,合并的时候就是左端点前缀最大值(下一个段的)与右端点后缀最小值(这一个段的)相加取最大值就好了。

然后再有一颗线段树维护维护连续块的答案。

考虑查询,我们直接处理 r-l+1\le c 的询问,这可以直接做不限制的子段最大值(当然,这可能需要你再开一颗线段树)。

那么对于左右端点处在不相邻的段的情况,我们直接做中间的区间,然后左端点就查询左端点在 [l,L] 的跨段情况,其中 L 代表左端点所在块的右端点。然后右端点就查询左端点在 [R,r-c] 的跨段情况,其中 R 代表 r-c 所在块的左端点。

然后发现可能还有左右端点不跨块的情况,我们再查询 [l,l+c-1][r-c+1,r] 的情况就好了。

对于左右端点处在相邻的段的情况,我们直接查左端点在 [l,r-c] 的跨段情况,以及 [l,l+c-1][r-c+1,r] 的情况就好了。

对着题解看感受不出来的,可以先尝试着写一写,然后你就会了。

代码无法通过加强版,需要你把 cin 改成快读。

#include <bits/stdc++.h>
#define N 400010
#define pos Index(l, r)
#define ls Index(l, mid)
#define rs Index(mid + 1, r)
#define int long long
using namespace std;
int n, m, c, a[N];
int L[N], R[N], idx[N];      // 块的数据
inline int Index(int l, int r) { return l + r | l != r; }
// 区间最大子段和数据
struct data1 {
    int mx, pre, suf, sum;
    data1(int x = 0):mx(max(x, 0ll)), pre(max(x, 0ll)), suf(max(x, 0ll)), sum(x) {}
    friend data1 operator+(const data1& a, const data1& b) {
        data1 c;
        c.mx = max(a.mx, max(b.mx, a.suf + b.pre));
        c.pre = max(a.pre, a.sum + b.pre);
        c.suf = max(b.suf, a.suf + b.sum);
        c.sum = a.sum + b.sum;
        return c;
    }
};
// 前后段和数据
struct data2 {
    // suf是这一段的后缀和,pre是下一段的前缀和,sum1是这一段的和,sum2是下一段的和
    int mx, pre, suf, sum1, sum2;
    data2(int x = 0, int y = 0):mx(max(x, 0ll)), pre(max(y, 0ll)), suf(max(x, 0ll)), sum1(x), sum2(y) {}
    friend data2 operator+(const data2& a, const data2& b) {
        data2 c;
        c.mx = max(a.mx + b.sum1, max(a.sum2 + b.mx, a.pre + b.suf));
        c.pre = max(a.pre, a.sum2 + b.pre);
        c.suf = max(b.suf, b.sum1 + a.suf);
        c.sum1 = a.sum1 + b.sum1;
        c.sum2 = a.sum2 + b.sum2;
        return c;
    }
};
// 全局线段树,维护区间最大子段和
namespace sg1 {
    data1 t[N * 2];
    void build(int l, int r) {
        if (l == r) return (void)(t[pos] = data1(a[l]));
        int mid = (l + r) >> 1;
        build(l, mid);
        build(mid + 1, r);
        t[pos] = t[ls] + t[rs];
    }
    void update(int x, int l, int r) {
        if (l == r) return (void)(t[pos] = data1(a[l]));
        int mid = (l + r) >> 1;
        if (x <= mid) update(x, l, mid);
        else update(x, mid + 1, r);
        t[pos] = t[ls] + t[rs];
    }
    data1 query(int nl, int nr, int l, int r) {
        if (nl > nr) return 0;
        if (nl <= l && r <= nr) return t[pos];
        int mid = (l + r) >> 1;
        data1 res;
        bool flag = 0;
        if (nl <= mid) res = query(nl, nr, l, mid), flag = 1;
        if (mid < nr) {
            if(flag) res = res + query(nl, nr, mid + 1, r);
            else res = query(nl, nr, mid + 1, r);
        }
        return res;
    }
}
// 局部线段树,维护每个块内的元素data1和data2
namespace sg2 {
    data1 t1[N * 2];
    data2 t2[N * 2];
    void build(int l, int r) {
        if (l == r) return (void)(t1[pos] = data1(a[l]), t2[pos] = data2(a[l], a[l + c]));
        int mid = (l + r) >> 1;
        build(l, mid);
        build(mid + 1, r);
        t1[pos] = t1[ls] + t1[rs], t2[pos] = t2[ls] + t2[rs];
    }
    void update(int x, int l, int r) {
        if (l == r) return (void)(t1[pos] = data1(a[l]), t2[pos] = data2(a[l], a[l + c]));
        int mid = (l + r) >> 1;
        if (x <= mid) update(x, l, mid);
        else update(x, mid + 1, r);
        t1[pos] = t1[ls] + t1[rs], t2[pos] = t2[ls] + t2[rs];
    }
    data1 query1(int nl, int nr, int l, int r) {
        if (nl <= l && r <= nr) return t1[pos];
        int mid = (l + r) >> 1;
        data1 res;
        bool flag = 0;
        if (nl <= mid) res = query1(nl, nr, l, mid), flag = 1;
        if (mid < nr) {
            if(flag) res = res + query1(nl, nr, mid + 1, r);
            else res = query1(nl, nr, mid + 1, r);
        }
        return res;
    }
    data2 query2(int nl, int nr, int l, int r) {
        if (nl <= l && r <= nr) return t2[pos];
        int mid = (l + r) >> 1;
        data2 res;
        bool flag = 0;
        if (nl <= mid) res = query2(nl, nr, l, mid), flag = 1;
        if (mid < nr) {
            if(flag) res = res + query2(nl, nr, mid + 1, r);
            else res = query2(nl, nr, mid + 1, r);
        }
        return res;
    }
}
// 维护局部信息的全局线段树,维护每个块内的答案
namespace sg3 {
    int t[N * 2];
    void build(int l, int r) {
        if (l == r) return (void)(t[pos] = max(sg2::query1(L[l], R[l], L[l], R[l]).mx, sg2::query2(L[l], R[l], L[l], R[l]).mx));
        int mid = (l + r) >> 1;
        build(l, mid);
        build(mid + 1, r);
        t[pos] = max(t[ls], t[rs]);
    }
    void update(int x, int l, int r) {
        if (l == r) return (void)(t[pos] = max(sg2::query1(L[l], R[l], L[l], R[l]).mx, sg2::query2(L[l], R[l], L[l], R[l]).mx));
        int mid = (l + r) >> 1;
        if (x <= mid) update(x, l, mid);
        else update(x, mid + 1, r);
        t[pos] = max(t[ls], t[rs]);
    }
    int query(int nl, int nr, int l, int r) {
        if (nl <= l && r <= nr) return t[pos];
        int mid = (l + r) >> 1;
        int res = 0;
        if (nl <= mid) res = max(res, query(nl, nr, l, mid));
        if (mid < nr) res = max(res, query(nl, nr, mid + 1, r));
        return res;
    }
}
inline int query(int l, int r) {
    if(r - l + 1 <= c) return sg1::query(l, r, 1, n).mx;
    int ans = 0;
    if(idx[l] + 1 < idx[r] - 1) ans = max(ans, sg3::query(idx[l] + 1, idx[r] - 2, 1, n / c));
    if(idx[l] + 1 <= idx[r] - 1) ans = max(ans, sg2::query1(L[idx[r] - 1], R[idx[r] - 1], L[idx[r] - 1], R[idx[r] - 1]).mx);
    if(idx[l] + 1 == idx[r]) {
        ans = max(ans, sg2::query2(l, r - c, L[idx[l]], R[idx[l]]).mx + sg1::query(r - c + 1, l + c - 1, 1, n).sum);
    } else {
        ans = max(ans, sg2::query2(l, R[idx[l]], L[idx[l]], R[idx[l]]).mx + sg1::query(L[idx[l] + 1], l + c - 1, 1, n).sum);
        ans = max(ans, sg2::query2(L[idx[r] - 1], r - c, L[idx[r] - 1], R[idx[r] - 1]).mx + sg1::query(r - c + 1, R[idx[r] - 1], 1, n).sum);
    }
    ans = max(ans, sg1::query(l, l + c - 1, 1, n).mx);
    ans = max(ans, sg1::query(r - c + 1, r, 1, n).mx);
    return ans;
}
signed main() {
    cin >> n >> m >> c;
    for (int i = 1; i <= n; i++) cin >> a[i];
    n = (n + c - 1) / c * c;
    for (int i = 1; i <= n; i++) idx[i] = (i - 1) / c + 1, R[idx[i]] = i;
    for (int i = n; i >= 1; i--) L[idx[i]] = i;
    sg1::build(1, n);
    for (int i = 1; i <= n / c; i++) sg2::build(L[i], R[i]);
    sg3::build(1, n / c);
    while (m--) {
        int op, x, y;
        cin >> op >> x >> y;
        if (op == 1) {
            a[x] = y;
            sg1::update(x, 1, n);
            sg2::update(x, L[idx[x]], R[idx[x]]);
            sg3::update(idx[x], 1, n / c);
            if (x > c) sg2::update(x - c, L[idx[x] - 1], R[idx[x] - 1]);
            if (x > c) sg3::update(idx[x] - 1, 1, n / c);
        } else {
            cout << query(x, y) << endl;
        }
    }
}