题解:P9877 [EC Final 2021] Vacation
套路地,将数列每
如果答案在一个块内,这是好处理的,我们处理一个线段树单独维护(也就是开
因为查询可能涉及多个块,所以我们需要再有一颗线段树维护连续块的答案。
然后观察到相邻的两个块的情况:
如果
我们把右边的块拿下来:
如果原本满足
那么我们可以维护区间的前缀最大值(是下一个段的),后缀最大值(是这一个段的),区间和,合并的时候就是左端点前缀最大值(下一个段的)与右端点后缀最小值(这一个段的)相加取最大值就好了。
然后再有一颗线段树维护维护连续块的答案。
考虑查询,我们直接处理
那么对于左右端点处在不相邻的段的情况,我们直接做中间的区间,然后左端点就查询左端点在
然后发现可能还有左右端点不跨块的情况,我们再查询
对于左右端点处在相邻的段的情况,我们直接查左端点在
对着题解看感受不出来的,可以先尝试着写一写,然后你就会了。
代码无法通过加强版,需要你把 cin 改成快读。
#include <bits/stdc++.h>
#define N 400010
#define pos Index(l, r)
#define ls Index(l, mid)
#define rs Index(mid + 1, r)
#define int long long
using namespace std;
int n, m, c, a[N];
int L[N], R[N], idx[N]; // 块的数据
inline int Index(int l, int r) { return l + r | l != r; }
// 区间最大子段和数据
struct data1 {
int mx, pre, suf, sum;
data1(int x = 0):mx(max(x, 0ll)), pre(max(x, 0ll)), suf(max(x, 0ll)), sum(x) {}
friend data1 operator+(const data1& a, const data1& b) {
data1 c;
c.mx = max(a.mx, max(b.mx, a.suf + b.pre));
c.pre = max(a.pre, a.sum + b.pre);
c.suf = max(b.suf, a.suf + b.sum);
c.sum = a.sum + b.sum;
return c;
}
};
// 前后段和数据
struct data2 {
// suf是这一段的后缀和,pre是下一段的前缀和,sum1是这一段的和,sum2是下一段的和
int mx, pre, suf, sum1, sum2;
data2(int x = 0, int y = 0):mx(max(x, 0ll)), pre(max(y, 0ll)), suf(max(x, 0ll)), sum1(x), sum2(y) {}
friend data2 operator+(const data2& a, const data2& b) {
data2 c;
c.mx = max(a.mx + b.sum1, max(a.sum2 + b.mx, a.pre + b.suf));
c.pre = max(a.pre, a.sum2 + b.pre);
c.suf = max(b.suf, b.sum1 + a.suf);
c.sum1 = a.sum1 + b.sum1;
c.sum2 = a.sum2 + b.sum2;
return c;
}
};
// 全局线段树,维护区间最大子段和
namespace sg1 {
data1 t[N * 2];
void build(int l, int r) {
if (l == r) return (void)(t[pos] = data1(a[l]));
int mid = (l + r) >> 1;
build(l, mid);
build(mid + 1, r);
t[pos] = t[ls] + t[rs];
}
void update(int x, int l, int r) {
if (l == r) return (void)(t[pos] = data1(a[l]));
int mid = (l + r) >> 1;
if (x <= mid) update(x, l, mid);
else update(x, mid + 1, r);
t[pos] = t[ls] + t[rs];
}
data1 query(int nl, int nr, int l, int r) {
if (nl > nr) return 0;
if (nl <= l && r <= nr) return t[pos];
int mid = (l + r) >> 1;
data1 res;
bool flag = 0;
if (nl <= mid) res = query(nl, nr, l, mid), flag = 1;
if (mid < nr) {
if(flag) res = res + query(nl, nr, mid + 1, r);
else res = query(nl, nr, mid + 1, r);
}
return res;
}
}
// 局部线段树,维护每个块内的元素data1和data2
namespace sg2 {
data1 t1[N * 2];
data2 t2[N * 2];
void build(int l, int r) {
if (l == r) return (void)(t1[pos] = data1(a[l]), t2[pos] = data2(a[l], a[l + c]));
int mid = (l + r) >> 1;
build(l, mid);
build(mid + 1, r);
t1[pos] = t1[ls] + t1[rs], t2[pos] = t2[ls] + t2[rs];
}
void update(int x, int l, int r) {
if (l == r) return (void)(t1[pos] = data1(a[l]), t2[pos] = data2(a[l], a[l + c]));
int mid = (l + r) >> 1;
if (x <= mid) update(x, l, mid);
else update(x, mid + 1, r);
t1[pos] = t1[ls] + t1[rs], t2[pos] = t2[ls] + t2[rs];
}
data1 query1(int nl, int nr, int l, int r) {
if (nl <= l && r <= nr) return t1[pos];
int mid = (l + r) >> 1;
data1 res;
bool flag = 0;
if (nl <= mid) res = query1(nl, nr, l, mid), flag = 1;
if (mid < nr) {
if(flag) res = res + query1(nl, nr, mid + 1, r);
else res = query1(nl, nr, mid + 1, r);
}
return res;
}
data2 query2(int nl, int nr, int l, int r) {
if (nl <= l && r <= nr) return t2[pos];
int mid = (l + r) >> 1;
data2 res;
bool flag = 0;
if (nl <= mid) res = query2(nl, nr, l, mid), flag = 1;
if (mid < nr) {
if(flag) res = res + query2(nl, nr, mid + 1, r);
else res = query2(nl, nr, mid + 1, r);
}
return res;
}
}
// 维护局部信息的全局线段树,维护每个块内的答案
namespace sg3 {
int t[N * 2];
void build(int l, int r) {
if (l == r) return (void)(t[pos] = max(sg2::query1(L[l], R[l], L[l], R[l]).mx, sg2::query2(L[l], R[l], L[l], R[l]).mx));
int mid = (l + r) >> 1;
build(l, mid);
build(mid + 1, r);
t[pos] = max(t[ls], t[rs]);
}
void update(int x, int l, int r) {
if (l == r) return (void)(t[pos] = max(sg2::query1(L[l], R[l], L[l], R[l]).mx, sg2::query2(L[l], R[l], L[l], R[l]).mx));
int mid = (l + r) >> 1;
if (x <= mid) update(x, l, mid);
else update(x, mid + 1, r);
t[pos] = max(t[ls], t[rs]);
}
int query(int nl, int nr, int l, int r) {
if (nl <= l && r <= nr) return t[pos];
int mid = (l + r) >> 1;
int res = 0;
if (nl <= mid) res = max(res, query(nl, nr, l, mid));
if (mid < nr) res = max(res, query(nl, nr, mid + 1, r));
return res;
}
}
inline int query(int l, int r) {
if(r - l + 1 <= c) return sg1::query(l, r, 1, n).mx;
int ans = 0;
if(idx[l] + 1 < idx[r] - 1) ans = max(ans, sg3::query(idx[l] + 1, idx[r] - 2, 1, n / c));
if(idx[l] + 1 <= idx[r] - 1) ans = max(ans, sg2::query1(L[idx[r] - 1], R[idx[r] - 1], L[idx[r] - 1], R[idx[r] - 1]).mx);
if(idx[l] + 1 == idx[r]) {
ans = max(ans, sg2::query2(l, r - c, L[idx[l]], R[idx[l]]).mx + sg1::query(r - c + 1, l + c - 1, 1, n).sum);
} else {
ans = max(ans, sg2::query2(l, R[idx[l]], L[idx[l]], R[idx[l]]).mx + sg1::query(L[idx[l] + 1], l + c - 1, 1, n).sum);
ans = max(ans, sg2::query2(L[idx[r] - 1], r - c, L[idx[r] - 1], R[idx[r] - 1]).mx + sg1::query(r - c + 1, R[idx[r] - 1], 1, n).sum);
}
ans = max(ans, sg1::query(l, l + c - 1, 1, n).mx);
ans = max(ans, sg1::query(r - c + 1, r, 1, n).mx);
return ans;
}
signed main() {
cin >> n >> m >> c;
for (int i = 1; i <= n; i++) cin >> a[i];
n = (n + c - 1) / c * c;
for (int i = 1; i <= n; i++) idx[i] = (i - 1) / c + 1, R[idx[i]] = i;
for (int i = n; i >= 1; i--) L[idx[i]] = i;
sg1::build(1, n);
for (int i = 1; i <= n / c; i++) sg2::build(L[i], R[i]);
sg3::build(1, n / c);
while (m--) {
int op, x, y;
cin >> op >> x >> y;
if (op == 1) {
a[x] = y;
sg1::update(x, 1, n);
sg2::update(x, L[idx[x]], R[idx[x]]);
sg3::update(idx[x], 1, n / c);
if (x > c) sg2::update(x - c, L[idx[x] - 1], R[idx[x] - 1]);
if (x > c) sg3::update(idx[x] - 1, 1, n / c);
} else {
cout << query(x, y) << endl;
}
}
}