题解:P14268 [ROI 2015 Day2] 路灯

· · 题解

维护一个长度为 n01 序列,支持 q 次以下操作:

初始时和每次操作后均需输出所有目前或曾经区间内元素全部为 1 的区间个数。

*** 考虑如何维护历史 $1$ 区间个数。 考虑把贡献拆开,我们固定一个右端点,对于每一个右端点 $i$,记录 $f(i)$ 为满足历史上存在 $[f(i), i]$ 全为 $1$ 的最小数。答案就是 $\sum\limits_{i = 1}^n(i - f(i)+1) = \sum\limits_{i=1}^n(i+1)-\sum\limits_{i = 1}^nf(i)$。显然答案第一项为 $\dfrac{n(n+3)}{2}$。 这样我们就把维护历史 $1$ 区间个数转换为区间操作 $f(i) \gets \min\{f(i), x\}$ 和区间求和。这是 Segment Tree Beats 经典操作。 考虑每次操作中 Segment Tree Beats 的修改区间是什么: - $c = 0$:无需操作。 - $c = 1$:我们把序列分成 $3$ 段:$[1, l], (l, r), [r, n]$,显然本次操作后 $(l, r)$ 内所有元素均为 $1$,因此我们只需考虑 $[1, l]$ 后缀的起始位置以及 $[r, n]$ 前缀的结束位置。 维护 $1$ 区间前缀、后缀是简单的。线段树维护即可。 然后我们记 $[1, l]$ 后缀的起始位置为 $L$,$[r, n]$ 前缀的结束位置为 $R$。每次修改操作即 $\forall L \le i \le R, f(i) = \min\{f(i), L\}$。 时间复杂度为 $O(n \log n)$,空间复杂度 $O(n)$,足以通过此题。 ::::success[代码] ```cpp #include <bits/stdc++.h> #define endl '\n' #define mid ((l + r) >> 1) #define ls (pos << 1) #define rs (pos << 1 | 1) using i32 = int; using i64 = long long; using i128 = __int128; using f64 = double; using f128 = long double; using p32 = std::pair<i32, i32>; using p64 = std::pair<i64, i64>; const i32 inf32 = 2e9; const i64 inf64 = 2e18; void solve(); signed main() { std::cin.tie(nullptr)->sync_with_stdio(false); i32 _ = 1; while (_--) solve(); return 0; } const i32 N = 3e5 + 5; i32 n, q; std::string s; struct node1 { i32 len, pre, suf, tag; } t1[N << 2]; struct node2 { i64 mx, smx, cnt, sum; } t2[N << 2]; inline i64 C(i64 x) { return x * (x + 1) / 2; } void pushup1(i32 pos) { t1[pos].len = t1[ls].len + t1[rs].len; t1[pos].pre = t1[ls].pre == t1[ls].len ? t1[ls].len + t1[rs].pre : t1[ls].pre; t1[pos].suf = t1[rs].suf == t1[rs].len ? t1[rs].len + t1[ls].suf : t1[rs].suf; } void apply1(i32 pos, i32 v) { if (v == 0) t1[pos].pre = t1[pos].suf = 0; else t1[pos].pre = t1[pos].suf = t1[pos].len; t1[pos].tag = v; } void build1(i32 l, i32 r, i32 pos) { t1[pos].len = r - l + 1, t1[pos].tag = -1; if (l == r) return t1[pos].pre = t1[pos].suf = s[l] - 48, void(); build1(l, mid, ls), build1(mid + 1, r, rs), pushup1(pos); } void pushdown1(i32 pos) { if (t1[pos].tag != -1) { apply1(ls, t1[pos].tag); apply1(rs, t1[pos].tag); t1[pos].tag = -1; } } void update1(i32 st, i32 ed, i32 v, i32 l, i32 r, i32 pos) { if (st <= l && r <= ed) return apply1(pos, v); pushdown1(pos); if (st <= mid) update1(st, ed, v, l, mid, ls); if (mid + 1 <= ed) update1(st, ed, v, mid + 1, r, rs); pushup1(pos); } node1 query1(i32 st, i32 ed, i32 l, i32 r, i32 pos) { if (st <= l && r <= ed) return t1[pos]; pushdown1(pos); if (ed <= mid) return query1(st, ed, l, mid, ls); else if (st > mid) return query1(st, ed, mid + 1, r, rs); else { node1 L = query1(st, ed, l, mid, ls), R = query1(st, ed, mid + 1, r, rs), ret; ret.len = L.len + R.len; ret.pre = L.pre == L.len ? L.len + R.pre : L.pre; ret.suf = R.suf == R.len ? R.len + L.suf : R.suf; ret.tag = -1; return ret; } } void pushup2(i32 pos) { t2[pos].sum = t2[ls].sum + t2[rs].sum; if (t2[ls].mx == t2[rs].mx) t2[pos] = {t2[ls].mx, std::max(t2[ls].smx, t2[rs].smx), t2[ls].cnt + t2[rs].cnt, t2[ls].sum + t2[rs].sum}; else if (t2[ls].mx > t2[rs].mx) t2[pos] = {t2[ls].mx, std::max(t2[ls].smx, t2[rs].mx), t2[ls].cnt, t2[ls].sum + t2[rs].sum}; else t2[pos] = {t2[rs].mx, std::max(t2[ls].mx, t2[rs].smx), t2[rs].cnt, t2[ls].sum + t2[rs].sum}; } void build2(i32 l, i32 r, i32 pos) { if (l == r) return t2[pos] = {(i64)l + 1, (i64)-4e18, 1, (i64)l + 1}, void(); build2(l, mid, ls), build2(mid + 1, r, rs), pushup2(pos); } void apply2(i32 pos, i64 x) { if (x >= t2[pos].mx) return; t2[pos].sum -= (t2[pos].mx - x) * t2[pos].cnt; t2[pos].mx = x; } void pushdown2(i32 pos) { if (t2[ls].mx > t2[pos].mx) apply2(ls, t2[pos].mx); if (t2[rs].mx > t2[pos].mx) apply2(rs, t2[pos].mx); } void chmin2(i32 st, i32 ed, i64 x, i32 l, i32 r, i32 pos) { if (x >= t2[pos].mx) return; if (st <= l && r <= ed && x > t2[pos].smx) return apply2(pos, x); if (l == r) return apply2(pos, x); pushdown2(pos); if (st <= mid) chmin2(st, ed, x, l, mid, ls); if (mid + 1 <= ed) chmin2(st, ed, x, mid + 1, r, rs); pushup2(pos); } void solve() { std::cin >> n >> q >> s; s = ' ' + s; build1(1, n, 1); build2(1, n, 1); i64 base = (i64)n * (n + 3) / 2; for (i32 i = 1; i <= n;) { if (s[i] == '1') { i32 j = i; while (j <= n && s[j] == '1') ++j; chmin2(i, j - 1, i, 1, n, 1); i = j; } else ++i; } std::cout << base - t2[1].sum << endl; for (i32 i = 1, l, r, c; i <= q; i++) { std::cin >> l >> r >> c; if (c == 0) update1(l, r, 0, 1, n, 1); else { update1(l, r, 1, 1, n, 1); i32 L = l - query1(1, l, 1, n, 1).suf + 1; i32 R = r + query1(r, n, 1, n, 1).pre - 1; if (L <= R) chmin2(L, R, L, 1, n, 1); } std::cout << base - t2[1].sum << endl; } } ``` ::::