题解:CF1998E2 Eliminating Balls With Merging (Hard Version)

· · 题解

E1 题解

这次题目是问你有多少个 i 可以消掉一个 [1,k] 的前缀。

首先,我们和刚才一样,计算每个 a_i 可以删除的最大区间。如果最大区间的左端点不是 1,显然它对答案没有任何贡献。

我们假设算出来的右端点是 R,那么 a_i 能删除的前缀的右端点就形成了一个 [x,R] 的区间。

我们发现,这个 x 也是可以二分的。

关于这个二分,我们的 check 就是和前面那样计算一下 a_i[1,x] 范围内能消除的最大区间。如果最大区间也是 [1,x] 的话,就可行。

二分完之后,我们只要用差分给 [x,R] 区间加一就好了。

时间复杂度又多了一个 \log,能 3.9s 卡过去。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9 + 7;
const int N = 200005;
const int INF = 0x3f3f3f3f;
int a[N], n;
ll s[N];
ll sum(int l, int r) {
    return s[r] - s[l - 1];
}
int st[N][30];
void st_init() {
    for (int i = 1; i <= n; i++) st[i][0] = a[i];
    int p = __lg(n);
    for (int k = 1; k <= p; k++) {
        for (int s = 1; s + (1 << k) <= n + 1; s++) {
            st[s][k] = max(st[s][k - 1], st[s + (1 << (k - 1))][k - 1]);
        }
    }
}
int mx(int l, int r) {
    int k = __lg(r - l + 1);
    int x = max(st[l][k], st[r - (1 << k) + 1][k]);
    return x;
}
int D[N];
bool check(int i, int x) {
    int l = i, r = i;
    ll res = a[i], lr = a[i];
    while (1) {
        int L = 1, R = l - 1, p = l;
        while (L <= R) {
            int mid = L + R >> 1;
            if (mx(mid, l - 1) <= res) {
                p = mid;
                R = mid - 1;
            } else {
                L = mid + 1;
            }
        }
        if (p <= l - 1) {
            res += sum(p, l - 1);
            l = p;
        }
        L = r + 1, R = x, p = r;
        while (L <= R) {
            int mid = L + R >> 1;
            if (mx(r + 1, mid) <= res) {
                p = mid;
                L = mid + 1;
            } else {
                R = mid - 1;
            }
        }
        if (r + 1 <= p) {
            res += sum(r + 1, p);
            r = p;
        }
        if (res == lr) break;
        lr = res;
        if (l == 1 && r == x) break;
    }
    if (l == 1 && r == x) return true;
    return false;
}
int main() {
    int _;
    scanf("%d", &_);
    while (_--) {
        scanf("%d%*d", &n);
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
        for (int i = 1; i <= n; i++) s[i] = s[i - 1] + a[i];
        for (int i = 1; i <= n; i++) D[i] = 0;
        st_init();
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            int l = i, r = i;
            ll res = a[i], lr = a[i];
            while (1) {
                int L = 1, R = l - 1, p = l;
                while (L <= R) {
                    int mid = L + R >> 1;
                    if (mx(mid, l - 1) <= res) {
                        p = mid;
                        R = mid - 1;
                    } else {
                        L = mid + 1;
                    }
                }
                if (p <= l - 1) {
                    res += sum(p, l - 1);
                    l = p;
                }
                L = r + 1, R = n, p = r;
                while (L <= R) {
                    int mid = L + R >> 1;
                    if (mx(r + 1, mid) <= res) {
                        p = mid;
                        L = mid + 1;
                    } else {
                        R = mid - 1;
                    }
                }
                if (r + 1 <= p) {
                    res += sum(r + 1, p);
                    r = p;
                }
                if (res == lr) break;
                lr = res;
                if (l == 1 && r == n) break;
            }
            if (l == 1) {
                int left = i, right = r, rr = r;
                while (left <= right) {
                    int mid = left + right >> 1;
                    if (check(i, mid)) {
                        right = mid - 1;
                        rr = mid;
                    } else {
                        left = mid + 1;
                    }
                }
                D[rr]++;
                D[r + 1]--;
            }
        }
        for (int i = 1; i <= n; i++) D[i] += D[i - 1];
        for (int i = 1; i <= n; i++) printf("%d ", D[i]);
        printf("\n");
    }
    return 0;
}