题解:CF2064F We Be Summing

· · 题解

前言

看完日麻新生はるか的光速AK,我大为震撼,于是写一篇题解加深对这题的印象。

正文

首先我们分析一下题目,题目要我们找序列 a 长度为 m 的子串 b,要求字串具有性质:

\min(b_1, b_2, \dots, b_i) + \max(b_{i+1}, b_{i+2}, \dots, b_m) = k

我们很快发现,对于一个任意的子串 b,随着 i 的增大,前面前缀求取最小值的部分单调不增,后面后缀求取最大值的部分也是单调不增,所以 b 的合法切分一定是 b 上连续的一段。

我们考虑一个数什么时候能成为最大值或者最小值,也就是一个数在什么时候会对和产生贡献。这个贡献一定是一个区间,也就是说,在这个区间内,这个数一定是最大/最小的。如果一个数最小值的贡献区间与另一个数的最大贡献区间有重合,那么重合部分就是合法的切分。

上述部分需要我们快速处理出贡献位置,这部分可以用单调栈 O(n) 处理。接下来就只剩下如何快速计算合法方案的数量了。我们分析一下上面的图片,我们只需要选上了 a_1a_2 及其内的子串,那么就一定是合法的,所以答案就是 a_1 合法贡献区间左边的元素数量乘以 a_2 合法贡献区间右边的元素。这个过程可以用树状数组加速。

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5+6;
int n, k;
int a[N];
int l1[N], l2[N], r1[N], r2[N];
int stk[N], tp;
vector<int> vec[N];
int cnt[N];
void solve() {
    int res = 0;
    cin>>n>>k;
    for(int i=1; i<=n; i++) cin>>a[i];
    stk[tp=0] = 0;
    for(int i=1; i<=n; i++) {
        while(tp && a[stk[tp]] > a[i]) --tp;
        l1[i] = stk[tp] + 1;
        stk[++tp] = i;
    }
    stk[tp=0] = 0;
    for(int i=1; i<=n; i++) {
        while(tp && a[stk[tp]] <= a[i]) --tp;
        l2[i] = stk[tp] + 1;
        stk[++tp] = i;
    }
    stk[tp=0] = n+1;
    for(int i=n; i; i--) {
        while(tp && a[stk[tp]] >= a[i]) --tp;
        r1[i] = stk[tp] - 1;
        stk[++tp] = i;
    }
    stk[tp=0] = n+1;
    for(int i=n; i; i--) {
        while(tp && a[stk[tp]] < a[i]) --tp;
        r2[i] = stk[tp] - 1;
        stk[++tp] = i;
    }
    for(int i=1; i<=n; i++) vec[a[i]].push_back(i);
    for(int l = k-n, r=n; l<=n; ++l, --r) {
        auto A = vec[l], B = vec[r];
        while(!A.empty()) {
            while(!B.empty() && B.back() > A.back()) {
                int val = r2[B.back()]-B.back()+1;
                for(int i = l2[B.back()];i<=n;i+=i&-i) cnt[i] += val;
                B.pop_back();
            }
            int x = A.back(); A.pop_back();
            int v = 0;
            for(int i=min(r1[x]+1, n); i; i-=i&-i) v += cnt[i];
            res += v*(x-l1[x]+1);
        }
        int m = (int)B.size(); B = vec[r];
        while(B.size() > m) {
            for(int i = l2[B.back()];i<=n;i+=i&-i) cnt[i] = 0;
            B.pop_back();
        }
    }
    for(int i=1; i<=n; i++) vec[i].clear();
    cout<<res<<"\n";
}

signed main(){
    ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);
    int T = 1;
    cin>>T;
    for(int i=1; i<=T; i++) solve();
    return 0;
}