CF1787I Treasure Hunt

· · 题解

\color{skyblue}{\mathsf{Description}}

给定数组 b,令 v_b\max{(\sum_{i=1}^{q}{b_i}+\sum_{i=s}^{t}{b_i})}。其中,要么 t \le q,要么 s \ge q,前面和后面都可以为空。求数组 b 所有子数组的 v 之和。

\color{skyblue}{\mathsf{Solution}}

可以发现,t \le qs \ge q 这个条件是无效的,因为,当 s \le qq \le t 时,可以构造 s' = st' = qq' = t,而这一组是符合条件的。

有了这个性质,前面和后面就没有关系了,所以可以分开考虑,即求出所有的 最大前缀和 和所有的 最大子段和,再把二者加起来就是答案。

对于前者,即所有子数组的 最大前缀和。我们可以预处理出 pre 前缀和数组,对于 i 开头的子数组都要减去 pre_{i-1},所以只要最大的 pre_j。而这个可以用单调栈维护,每次二分出单调栈中第一个大于 pre_{i-1} 的位置 p,大于 p 的就是其对应的值,而小于 p 的就是 0

对于后者,即所有子数组的 最大子段和。考虑 cdq 分治,每次维护前半部分的 最大后缀和 lmax 和 最大后缀子段和 lans,后半部分的 最大前缀和 rmax 和 最大前缀子段和 rans。对于一个在前半部分的 l 和一个在后半部分的 r 组成的区间 [l,r] 的答案就是 \max(lans_l,rans_r,lmax_l+rmax_r),显然这四个数组都是单调不降的,lans_i-lmax_irans_i-max_i 的值也是单调不降的。于是我们可以从 mid 倒序枚举到 L,维护两个指针,分三种情况计算贡献即可。

\color{skyblue}{\mathsf{Code}}

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

const int N = 1e6 + 5;
const int mod = 998244353;

int n, top;
int a[N], stk[N];
ll ans;
ll pre[N], f[N];
ll lans[N], lmax[N];
ll rans[N], rmax[N];
ll prans[N], prmax[N];

void cdq(int L, int R) {
    if (L == R) return ans += (a[L] > 0 ? a[L] : 0), void();
    int mid = (L + R) >> 1;
    cdq(L, mid), cdq(mid + 1, R);
    lmax[mid + 1] = rmax[mid] = 0;
    lans[mid + 1] = rans[mid] = 0;
    for (int i = mid; i >= L; --i) {
        lmax[i] = max(lmax[i + 1], pre[mid] - pre[i - 1]);
        lans[i] = max(lans[i + 1], 0ll) + a[i];
    }
    for (int i = mid; i >= L; --i) lans[i] = max(lans[i], lans[i + 1]);
    for (int i = mid + 1; i <= R; ++i) {
        rmax[i] = max(rmax[i - 1], pre[i] - pre[mid]);
        rans[i] = max(rans[i - 1], 0ll) + a[i];
    }
    for (int i = mid + 1; i <= R; ++i) rans[i] = max(rans[i], rans[i - 1]);
    prans[mid] = prmax[mid] = 0;
    for (int i = mid + 1; i <= R; ++i) {
        prans[i] = prans[i - 1] + rans[i];
        prmax[i] = prmax[i - 1] + rmax[i];
    }
    ll res = 0;
    int S = mid + 1, X = mid + 1, T = mid + 1;
    for (int i = mid; i >= L; --i) {
        while (S <= R && rans[S] <= lans[i]) ++S;
        while (X <= R && rmax[X] <= lans[i] - lmax[i]) ++X;
        if (S <= X) {
            res = (res + lans[i] * (S - mid - 1) % mod) % mod;
            res = (res + prans[X - 1] - prans[S - 1] + mod) % mod;
            // X ~ R
            T = max(T, X);
            while (T <= R && lmax[i] > rans[T] - rmax[T]) ++T;
            res = (res + prans[R] - prans[T - 1]) % mod;
            res = (res + prmax[T - 1] - prmax[X - 1] + lmax[i] * (T - X) % mod) % mod;
        } else {
            res = (res + lans[i] * (X - mid - 1) % mod) % mod;
            res = (res + prmax[S - 1] - prmax[X - 1] + lmax[i] * (S - X) % mod) % mod;
            // S ~ R
            T = max(T, S);
            while (T <= R && lmax[i] > rans[T] - rmax[T]) ++T;
            res = (res + prans[R] - prans[T - 1]) % mod;
            res = (res + prmax[T - 1] - prmax[S - 1] + lmax[i] * (T - S) % mod) % mod;
        }
    }
    ans = (ans + res) % mod;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int T;
    cin >> T;
    while (T--) {
        cin >> n;
        for (int i = 1; i <= n; ++i) {
            cin >> a[i];
            pre[i] = pre[i - 1] + a[i];
        }
        ans = 0;
        f[n + 1] = 0;
        pre[n + 1] = LLONG_MAX;
        stk[top = 1] = n + 1;
        for (int i = n; i; --i) {
            while (top && pre[i] >= pre[stk[top]]) --top;
            stk[++top] = i;
            f[top] = (f[top - 1] + pre[i] * (stk[top - 1] - stk[top]) % mod) % mod;
            int l = 1, r = top;
            while (l < r) {
                int mid = (l + r + 1) >> 1;
                if (pre[stk[mid]] > pre[i - 1]) l = mid;
                else r = mid - 1;
            }
            ans = (ans + pre[i - 1] * (stk[l] - i) % mod) % mod;
            ans = (ans + f[l]) % mod;
            ans = (ans - pre[i - 1] * (n - i + 1) % mod + mod) % mod;
        }
        cdq(1, n);
        cout << ans << '\n';
    }

    return 0;
}