「解题报告」P9334 [JOISC 2023 Day2] Mizuyokan 2

· · 题解

首先大力 DP 就是 O(n^3),注意到任意一种方案把小段缩到长度为 1 仍然满足条件,于是我们可以只记录每次选的小段的位置,这样就能 O(n^2) 了,且容易优化到 O(n \log n),但是这个做法没有什么拓展空间。

考虑分析大段的性质。首先大段有一个显然的必要条件,就是大段的和一定比段的左右两个数大,即 \sum_{i=l}^r a_i > \max(a_{l-1}, a_{r+1}),我们称满足这个条件的段为好段,而最优方案下小段长度为 1,于是上述条件就是一个充分条件。但是小段长度不一定等于 1,不过我们可以猜测,小段长度不等于 1 时,上述条件仍然是充分条件。

证明很简单,考虑左右两个段 [l_1, r_1], [l_2, r_2],如果有 a_{r_1 + 1} < a_{l_2 - 1},那么将 l_2 拓展到 r_1 + 2 显然仍然满足条件,否则将 r_1 拓展到 l_2 - 2 仍然满足条件。对于最靠左与最靠右的段,我们假设最靠左的段 [l, r],右边同理,那么如果 a_1 < a_{l - 1},那么把 l 拓展到 2 显然符合条件,否则可以在左面新加一段 [1, l - 1],这样也能够符合条件。也就是说,任意一种好段的划分,都可以找到一种合法方案的划分,那么我们只需要考虑找出最大的好段划分即可,这显然是等于答案的。而这个问题有很简单的贪心方式解决,对于右端点从左往右扫,如果左端点大于等于上一个右端点 +2 那么就选这个段。而这个做法就很好拓展了。

那么考虑设 f(r) 为以 r 为右端点的最小的好段的左端点,设 nxt(r) 为最小的 k 满足 f(k) \ge r + 2,答案就是从 [l, r] 中某个点开始跳 nxt,跳到区间外为止,最多能跳多少个点。需要特殊讨论一下开头结尾是不是长为 1 的小段,这都比较简单。那么静态区间询问就可以做了。

考虑如何修改。每次修改看起来会改变很多的 f(r),并不好修改,但是可以证明 nxt(r) - r = O(\log V),考虑从右边某个点开始向左右拓展,注意到向一边拓展时,如果 a_{l - 1} < \sum a_i,那么此时就可以停止拓展,否则有 a_{l - 1} \ge \sum a_i,注意到拓展后 \sum a_i 会翻倍,于是每次这样拓展都会使得区间和翻倍,那么显然翻倍 O(\log V) 次后就一定会满足条件了,所以说这样的拓展向左向右都最多进行 O(\log V) 次,于是就可以证明 nxt(r) - r = O(\log V) 了。

那么也就是说只有 O(\log V)nxt(r) 会发生改变,于是我们只需要每次更新这几个 nxt(r) 即可。然后维护跳链的信息可以使用 LCT 做到 O(n \log n \log V),不过这比较蠢,由于 nxt(r) - r = O(\log V),也就是说每次只会跳 O(\log V),于是我们可以直接用线段树维护一个区间内从前 O(\log V) 个为止开始跳,跳到区间外的第一个点是啥,然后这个东西是容易合并的,单次 pushup 是 O(\log V) 的,而我们修改连续的 O(\log V) 个值,所以只会更新 O(\log n + \log V) 个点,于是线段树维护的复杂度就是 O(n \log V (\log n + \log V))

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 250005, LOG = 65;
int n, q, a[MAXN];
int nxt[MAXN];
struct SegmentTree {
    struct Node {
        int l, r;
        pair<int, int> a[LOG + 1];
        pair<int, int> jmp(pair<int, int> q) {
            auto [i, v] = q;
            auto p = i <= min(r - l + 1, LOG) ? a[i] : make_pair(i - (r - l + 1), 0);
            p.second += v;
            return p;
        }
        Node operator+(Node b) {
            Node c; c.l = l, c.r = b.r;
            for (int i = 1; i <= min(c.r - c.l + 1, LOG); i++) c.a[i] = b.jmp(jmp(make_pair(i, 0)));
            return c;
        }
    } t[MAXN << 2];
#define lc (i << 1)
#define rc (i << 1 | 1)
    void update(int a, int b, int i = 1, int l = 0, int r = n) {
        if (l == r) {
            t[i].l = t[i].r = l;
            t[i].a[1] = { nxt[l] - l, 1 };
            return;
        }
        int mid = (l + r) >> 1;
        if (a <= mid) update(a, b, lc, l, mid);
        if (b > mid) update(a, b, rc, mid + 1, r);
        t[i] = t[lc] + t[rc];
    }
    Node query(int a, int b, int i = 1, int l = 0, int r = n) {
        if (a <= l && r <= b) return t[i];
        int mid = (l + r) >> 1;
        if (b <= mid) return query(a, b, lc, l, mid);
        if (a > mid) return query(a, b, rc, mid + 1, r);
        return query(a, b, lc, l, mid) + query(a, b, rc, mid + 1, r);
    }
} st;
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    a[0] = a[n + 1] = -1;
    nxt[n + 1] = n + 1;
    int cc = 0;
    auto solve = [&](int l, int r) {
        for (int i = r; i >= l; i--) {
            nxt[i] = nxt[i + 1];
            long long sum = 0;
            for (int j = i + 2; j <= min(n, i + LOG); j++) {
                sum += a[j];
                cc++;
                if (sum > a[i + 1] && sum > a[j + 1]) {
                    nxt[i] = min(nxt[i], j);
                    break;
                }
            }
        }
        st.update(l, r);
    };
    solve(0, n);
    scanf("%d", &q);
    while (q--) {
        {
            int x, y; scanf("%d%d", &x, &y);
            a[x] = y;
            solve(max(0, x - LOG), x);
        }
        {
            int l, r; scanf("%d%d", &l, &r); l++;
            int ans = 1;
            int pre = r, suf = l;
            long long sum = 0;
            for (int i = l; i < r; i++) {
                sum += a[i];
                if (a[i + 1] < sum) {
                    pre = i;
                    break;
                }
            }
            sum = 0;
            for (int i = r; i > l; i--) {
                sum += a[i];
                if (a[i - 1] < sum) {
                    suf = i;
                    break;
                }
            }
            auto query = [&](int l, int r) {
                if (l > r && nxt[l - 1] >= r) return INT_MIN;                
                return 2 * (st.query(l - 1, r - 1).a[1].second - 1) + 1;
            };
            ans = max(ans, query(l, r));
            ans = max(ans, query(pre + 1, r) + 1);
            ans = max(ans, query(l, suf - 1) + 1);
            ans = max(ans, query(pre + 1, suf - 1) + 2);
            printf("%d\n", ans);
        }
    }
    return 0;
}