题解 P5459 【[BJOI2016]回转寿司】

· · 题解

Description

给定一个长度为 n 的序列 \{a\},现在要从中选出一段连续子序列 [l,r],使得 L \leq \sum\limits_{i=l}^r a_i \leq R,求方案数。

(1 \leq n \leq 10^5, | a_i | \leq 10^5, 1 \leq L,R \leq 10^9)

Solution

我们可以枚举 r = 1 \sim n,求出对于每个 r 有多少 l 符合条件,累加即是答案。

先预处理出前缀和数组 pre,那么 \sum\limits_{i=l}^r a_i 的值为 pre_r - pre_{l - 1},当且仅当 L \leq pre_r - pre_{l - 1} \leq Rl 符合条件。

变形一下这个式子,可得 pre_r - R \leq pre_{l - 1} \leq pre_r -L

所以我们只需要找到有多少个 pre_{l - 1} \left(l \in [1,r] \right)[pre_r - R,pre_r - L] 区间内。单点修改,区间查询,可以离散化后用树状数组维护,这里我使用了动态开点线段树,原理相同(不要忘记初始时插入 pre_0 = 0)。时间复杂度为 O(n \log n)

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

template <class T>
inline void read(T &x) {
    x = 0;
    char c = getchar();
    bool f = 0;
    for (; !isdigit(c); c = getchar()) f ^= c == '-';
    for (; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
    x = f ? -x : x;
}

template <class T>
inline void write(T x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    T y = 1;
    int len = 1;
    for (; y <= x / 10; y *= 10) ++len;
    for (; len; --len, x %= y, y /= 10) putchar(x / y + 48);
}

const int MAXN = 1e5, LOG = 34;
const LL MAXM = 1e10;
int n, l, r;
LL ans, pre[MAXN + 5];
struct SegmentTree {//动态开点线段树
    int root, tot, sum[MAXN * LOG * 4 + 5], lson[MAXN * LOG * 4 + 5], 
    rson[MAXN * LOG * 4 + 5];

    inline void pushUp(int x) {
        sum[x] = sum[lson[x]] + sum[rson[x]];
    }
    void update(int &x, LL q, LL p, LL l = -MAXM, LL r = MAXM) {
        if (!x) x = ++tot;
        if (l == r) {
            sum[x] += p;
            return;
        }
        LL mid = (l + r) >> 1;
        if (q <= mid) update(lson[x], q, p, l, mid);
        else update(rson[x], q, p, mid + 1, r);
        pushUp(x);
    }
    int query(int &x, LL ql, LL qr, LL l = -MAXM, LL r = MAXM) {
        if (!x) x = ++tot;
        if (ql <= l && qr >= r) return sum[x];
        LL res = 0, mid = (l + r) >> 1;
        if (ql <= mid) res += query(lson[x], ql, qr, l, mid);
        if (qr > mid) res += query(rson[x], ql, qr, mid + 1, r);
        return res;
    }
} tr;

int main() {
    read(n), read(l), read(r);
    for (int x, i = 1; i <= n; ++i) {
        read(x);
        pre[i] = pre[i - 1] + x;//预处理前缀和
    }
    tr.update(tr.root, pre[0], 1);//插入 pre[0]
    for (int i = 1; i <= n; ++i) {//枚举右端点
        ans += tr.query(tr.root, pre[i] - r, pre[i] - l);
        //l - 1 在区间 [0,r) 内
        tr.update(tr.root, pre[i], 1);//插入 pre[i]
    }
    write(ans);
    putchar('\n');
    return 0;
}