P8797 [蓝桥杯 2022 国 A] 三角序列

· · 题解

upd on 2023/10/22:修正了柿子。

很显然要二分 h,那么问题就是快速处理中间高度小于 h 的圆圈个数。

可以把一个区间分成三部分:

左端点所在三角形 L,右端点所在三角形 R,中间的三角形 L+1 \sim R-1

然后分类讨论:

说明左右端点在同个三角形内,我们可以快速计算高度小于 h 的圆圈个数。

int calc(int i, int l, int r, int h) {
    int v1, v2;
    if (b[i] == 0) v1 = l, v2 = r;
    if (b[i] == 1) v1 = a[i] - r + 1, v2 = a[i] - l + 1;

    if (h <= v1) return h * (r - l + 1);
    if (h > v2)  return (v2 + v1) * (r - l + 1) / 2;
    return (v2 - h) * h + (h + v1) * (h - v1 + 1) / 2;
}

比较麻烦,左右两边可以通过 calc 函数快速计算,问题在于中间的完整三角形。

直接计算符合条件的圆圈个数比较难,我们可以用总个数-不符合条件的个数。

设中间完整三角形高度大于h的高度为 h_1, h_2, \cdots, h_k,那么不符合条件圆圈个数等于:

\begin{aligned} & \frac{1}{2} \sum\limits_{i = 1}^k (h_i - h) (h_i - h + 1) \\&= \frac{1}{2} \sum\limits_{i = 1}^k (h_i^2 - 2h_ih + h^2 + h_i - h) \end{aligned}

S_1= \sum_{i=1}^{k}h_i , S_2= \sum_{i=1}^{k}h_i^{2}

那么原式 = \frac{1}{2}[S_2 - (2h - 1)S_1 + (h^2 - h)k]

显然可以主席树维护,快速查询 S_1 , S_2 , k,进而计算出符合条件的圆圈个数。

#include<cmath>
#include<cstdio>
#include<algorithm>
#define int long long
using namespace std;
const int M = 1e6 + 10;
const int inf = 1000000;

int L[M], R[M];
int pre[M], cnt_pre[M], a[M], b[M], n, m;

int calc(int i, int l, int r, int h) {
    int v1, v2;
    if (b[i] == 0) v1 = l, v2 = r;
    if (b[i] == 1) v1 = a[i] - r + 1, v2 = a[i] - l + 1;

    if (h <= v1) return h * (r - l + 1);
    if (h > v2)  return (v2 + v1) * (r - l + 1) / 2;
    return (v2 - h) * h + (h + v1) * (h - v1 + 1) / 2;
}

struct PersistentTree {
    int lc, rc, sum, sqs, cnt;
    //sum[k,k]=k sqs[k,k]=k*k cnt[k,k]=1
}tr[M << 3];
int rt[M], cntv;
void pushup(int k) {
    tr[k].sum = tr[tr[k].lc].sum + tr[tr[k].rc].sum;
    tr[k].sqs = tr[tr[k].lc].sqs + tr[tr[k].rc].sqs;
    tr[k].cnt = tr[tr[k].lc].cnt + tr[tr[k].rc].cnt;
}
void update(int &k, int pre, int l, int r, int pos, int h) {
    tr[k = ++cntv] = tr[pre];
    if (l == r) {
        tr[k].sum += h;
        tr[k].sqs += h * h;
        tr[k].cnt += 1;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid) update(tr[k].lc, tr[pre].lc, l, mid, pos, h);
    else update(tr[k].rc, tr[pre].rc, mid + 1, r, pos, h);
    pushup(k);
}
int query_sum(int k, int pre, int L, int R, int l, int r) {
    if (L <= l and r <= R) return tr[k].sum - tr[pre].sum;
    int mid = (l + r) >> 1, ans = 0;
    if (L <= mid) ans += query_sum(tr[k].lc, tr[pre].lc, L, R, l, mid);
    if (R > mid)  ans += query_sum(tr[k].rc, tr[pre].rc, L, R, mid + 1, r);
    return ans;
}
int query_sqs(int k, int pre, int L, int R, int l, int r) {
    if (L <= l and r <= R) return tr[k].sqs - tr[pre].sqs;
    int mid = (l + r) >> 1, ans = 0;
    if (L <= mid) ans += query_sqs(tr[k].lc, tr[pre].lc, L, R, l, mid);
    if (R > mid)  ans += query_sqs(tr[k].rc, tr[pre].rc, L, R, mid + 1, r);
    return ans;
}
int query_cnt(int k, int pre, int L, int R, int l, int r) {
    if (L <= l and r <= R) return tr[k].cnt - tr[pre].cnt;
    int mid = (l + r) >> 1, ans = 0;
    if (L <= mid) ans += query_cnt(tr[k].lc, tr[pre].lc, L, R, l, mid);
    if (R > mid)  ans += query_cnt(tr[k].rc, tr[pre].rc, L, R, mid + 1, r);
    return ans;
}

signed main() {
    scanf("%lld%lld", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%lld%lld", &a[i], &b[i]);
        update(rt[i], rt[i - 1], 1, inf, a[i], a[i]);
        pre[i] = pre[i - 1] + a[i];
        cnt_pre[i] = cnt_pre[i - 1] + a[i] * (a[i] + 1) / 2;
        L[i] = pre[i - 1] + 1, R[i] = pre[i];
    }
    for (int i = 1, l, r, v; i <= m; i++) {
        scanf("%lld%lld%lld", &l, &r, &v);
        int ll = lower_bound(pre + 1, pre + 1 + n, l) - pre;
        int rr = lower_bound(pre + 1, pre + 1 + n, r) - pre;

        int lb = 1, rb = inf, ans = -1;
        while (lb <= rb) {
            int mid = (lb + rb) >> 1;
            int res;
            if (ll == rr) res = calc(ll, l - L[ll] + 1, r - L[ll] + 1, mid);
            else {
                int tmp_sum = query_sum(rt[rr - 1], rt[ll], mid, inf, 1, inf) - query_sum(rt[rr - 1], rt[ll], mid, mid, 1, inf);//S1
                int tmp_sqs = query_sqs(rt[rr - 1], rt[ll], mid, inf, 1, inf) - query_sqs(rt[rr - 1], rt[ll], mid, mid, 1, inf);//S2
                int tmp_cnt = query_cnt(rt[rr - 1], rt[ll], mid, inf, 1, inf) - query_cnt(rt[rr - 1], rt[ll], mid, mid, 1, inf);//k
                res = (tmp_sqs - 2 * tmp_sum * mid + tmp_cnt * mid * mid + tmp_sum - tmp_cnt * mid) / 2;
                res = cnt_pre[rr - 1] - cnt_pre[ll] - res;
                res += calc(ll, l - L[ll] + 1, R[ll] - L[ll] + 1, mid) + calc(rr, 1, r - L[rr] + 1, mid);
            }
            if (res >= v) ans = mid, rb = mid - 1;
            else lb = mid + 1;
        }
        printf("%lld\n", ans);
    }
    return 0;
}