题解:P10169 [DTCPC 2024] mex,min,max
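首先注意到区间的 $\operatorname{mex}$ 和 $\min$ 至少有一个为 $0$：若区间包含 $0$，则 $\min=0$；否则 $\operatorname{mex}=0$。因此要统计的条件 $\max-\min-\operatorname{mex}\le k$ 可以按这两种情况分别处理。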
先假设区间包含 $0$（即 $\min=0$），对整个序列统计满足 $\max-\operatorname{mex}\le k$ 的区间个数；不包含 $0$ 的区间最后再单独修正。
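然后是对最大值建笛卡尔树（用单调栈实现）：枚举每个位置 $i$ 作为区间最大值，设 $[l,r]$ 为以 $a_i$ 为最大值的极长区间（向左遇到 $\ge a_i$ 的位置停止，向右遇到 $>a_i$ 的位置停止，这样每个区间只被它最左边的最大值位置统计一次）。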
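此时对于 $l\le L\le i\le R\le r$，区间 $[L,R]$ 的条件为 $\operatorname{mex}(L,R)\ge a_i-k$，即 $0,1,\dots,a_i-k-1$ 都在 $a_L,\dots,a_R$ 中出现。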
我这里直接计算每个 $i$ 贡献的区间个数。假如 $a_i\le k$，则 $\operatorname{mex}\ge 0\ge a_i-k$ 恒成立，贡献为 $(i-l+1)(r-i+1)$。
否则我们考虑在笛卡尔树上遍历轻子树，也就是遍历 $[l,i]$ 与 $[i,r]$ 中较短的一侧并枚举该侧的端点；由启发式的经典结论，总枚举次数为 $O(n\log n)$。
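然后在可持久化线段树（主席树）上查位置。枚举左端点 $L$ 时，在版本 $\mathrm{rt1}_L$（以值域 $[0,n]$ 为下标，存每个值在 $L$ 及之后第一次出现的位置）上查询 $[0,a_i-k-1]$ 的最大值 $p$，合法右端点为 $[\max(p,i),r]$。枚举右端点 $R$ 时对称处理：在 $\mathrm{rt2}_R$（存每个值在 $R$ 及之前最后一次出现的位置）上查询最小值 $q$，合法左端点为 $[l,\min(q,i)]$。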
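对于不包含 $0$ 的区间，$\operatorname{mex}=0$，真正的条件是 $\max-\min\le k$。把序列按 $0$ 切成若干极长段。上面的统计在每段内算进了 $\max\le k$ 的区间，需要在段内用同样的方法重新算一遍并减掉。然后对每个右端点 $i$，用 ST 表二分出最小的左端点 $L$ 使 $\max-\min\le k$，累加 $i-L+1$。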
复杂度为 $O(n\log^2 n)$：轻子树枚举共 $O(n\log n)$ 次，每次在主席树上查询 $O(\log n)$；二分部分为 $O(n\log n)$。
比较考验码力。
#include <bits/stdc++.h>
using namespace std;
// ---- Shared template: capacity constant, type shorthands, loop macros, min/max updaters ----
const int N = 500005;  // size used for the static arrays below
// const int V = ;
// const int mod = ;

// Short spellings for frequently used types.
using us = unsigned;
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
using vi = vector<int>;
using vpi = vector<pii>;
using vl = vector<ll>;

// Max-heap (pq) and min-heap (pqg) shorthands.
template <class T> using pq = priority_queue<T>;
template <class T> using pqg = priority_queue<T, vector<T>, greater<T>>;

// Loop macros: rep / per include both ends; repr / perr stop before b.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define repr(i, a, b) for (int i = (a); i < (b); ++i)
#define per(i, a, b) for (int i = (a); i >= (b); --i)
#define perr(i, a, b) for (int i = (a); i > (b); --i)
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define pb push_back

// ckmn: lower a to b when b is smaller.
template <class T1, class T2> inline void ckmn(T1 &a, T2 b) {
    if (a > b) a = b;
}
// ckmx: raise a to b when b is larger.
template <class T1, class T2> inline void ckmx(T1 &a, T2 b) {
    if (a < b) a = b;
}
namespace IO {
// Optional buffered input: uncomment both lines below to replace getchar() with an
// fread-backed reader (it still yields EOF when the input is exhausted).
// char buf[1 << 23], *p1 = buf, *p2 = buf;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)

// Reads the next run of decimal digits from stdin into a.
// All non-digit characters before it are skipped, including '-', so "-7" reads as 7.
// If the input ends before any digit is found, a is set to 0 and the function returns.
// (Previously the EOF value stored in the unsigned c never matched a digit, so the
// skip loop spun forever.) The parameter c is only scratch space for the current character.
template <class T> void rd(T &a, unsigned c = 0) {
    const unsigned eof = static_cast<unsigned>(EOF);  // getchar()'s EOF as it looks once stored in c
    while (c = getchar(), c != eof && (c < 48 || c > 57));
    for (a = 0; c >= 48 && c <= 57; c = getchar()) a = (a << 3) + (a << 1) + (c ^ 48);
}

// Writes x to stdout in decimal, with no separator. Assumes x >= 0.
template <class T> void wrt(T x) { if (x > 9) wrt(x / 10); putchar(x % 10 ^ 48); }
} using IO::rd; using IO::wrt;
int n, k, a[N];
// Sparse tables over a[1..n]: st1[i][j] / st2[i][j] hold the min / max of
// a[i .. i + 2^j - 1]; lg2[x] = floor(log2(x)). All three are filled in solve().
int st1[N][20], st2[N][20], lg2[N];

// Minimum of a[l..r] (1 <= l <= r <= n), taken from two overlapping power-of-two windows.
int stq1(int l, int r) {
    const int lvl = lg2[r - l + 1];
    const int head = st1[l][lvl];
    const int tail = st1[r - (1 << lvl) + 1][lvl];
    return head < tail ? head : tail;
}

// Maximum of a[l..r] (1 <= l <= r <= n), taken from two overlapping power-of-two windows.
int stq2(int l, int r) {
    const int lvl = lg2[r - l + 1];
    const int head = st2[l][lvl];
    const int tail = st2[r - (1 << lvl) + 1][lvl];
    return head < tail ? tail : head;
}
namespace smt1 {
// Persistent segment tree over an index range [L, R]; solve() uses [0, n].
// Each node stores the maximum of the int values held at the leaves below it.
// A freshly built tree holds n + 1 everywhere. update() only writes newly allocated
// nodes, so every root it returns remains a valid, independent version.
struct Seg {
    int ls, rs, mx;  // left/right child node ids (0 = no child) and subtree maximum
} tr[N * 25];
// Node-field accessor macros; they expand against whichever `tr` is in scope.
// smt2 also uses ls/rs and M (the floor midpoint of the current [L, R]).
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define mx(p) tr[p].mx
#define M (L + R >> 1)
int tot;  // number of nodes allocated so far; ids start at 1

// Allocates a complete tree for [L, R] with every node set to n + 1; returns its root.
int build(int L, int R) {
    const int node = ++tot;
    tr[node].mx = n + 1;
    if (L == R) return node;
    const int mid = (L + R) >> 1;
    tr[node].ls = build(L, mid);
    tr[node].rs = build(mid + 1, R);
    return node;
}

// Returns the root of a new version equal to version p except that leaf x holds k.
// Only the nodes on the root-to-leaf path are copied.
int update(int p, int x, int k, int L, int R) {
    const int node = ++tot;
    tr[node] = tr[p];
    if (L == R) {
        tr[node].mx = k;
        return node;
    }
    const int mid = (L + R) >> 1;
    if (x <= mid) tr[node].ls = update(tr[p].ls, x, k, L, mid);
    else tr[node].rs = update(tr[p].rs, x, k, mid + 1, R);
    tr[node].mx = max(tr[tr[node].ls].mx, tr[tr[node].rs].mx);
    return node;
}

// Maximum over leaves [l, r] in the version rooted at p, which spans [L, R].
// Expects l <= r with [l, r] overlapping [L, R].
int query(int p, int l, int r, int L, int R) {
    if (l <= L && R <= r) return tr[p].mx;
    const int mid = (L + R) >> 1;
    if (r <= mid) return query(tr[p].ls, l, r, L, mid);
    if (l > mid) return query(tr[p].rs, l, r, mid + 1, R);
    const int left_best = query(tr[p].ls, l, r, L, mid);
    const int right_best = query(tr[p].rs, l, r, mid + 1, R);
    return max(left_best, right_best);
}
}
namespace smt2 {
// Persistent segment tree over an index range [L, R]; solve() uses [0, n].
// Each node stores the minimum of the int values held at the leaves below it.
// A freshly built tree holds 0 everywhere. update() only writes newly allocated
// nodes, so every root it returns remains a valid, independent version.
struct Seg {
    int ls, rs, mn;  // left/right child node ids (0 = no child) and subtree minimum
} tr[N * 25];
#define mn(p) tr[p].mn
int tot;  // number of nodes allocated so far; ids start at 1

// Allocates a complete tree for [L, R] with every node set to 0; returns its root.
int build(int L, int R) {
    const int node = ++tot;
    tr[node].mn = 0;
    if (L == R) return node;
    const int mid = (L + R) >> 1;
    tr[node].ls = build(L, mid);
    tr[node].rs = build(mid + 1, R);
    return node;
}

// Returns the root of a new version equal to version p except that leaf x holds k.
// Only the nodes on the root-to-leaf path are copied.
int update(int p, int x, int k, int L, int R) {
    const int node = ++tot;
    tr[node] = tr[p];
    if (L == R) {
        tr[node].mn = k;
        return node;
    }
    const int mid = (L + R) >> 1;
    if (x <= mid) tr[node].ls = update(tr[p].ls, x, k, L, mid);
    else tr[node].rs = update(tr[p].rs, x, k, mid + 1, R);
    tr[node].mn = min(tr[tr[node].ls].mn, tr[tr[node].rs].mn);
    return node;
}

// Minimum over leaves [l, r] in the version rooted at p, which spans [L, R].
// Expects l <= r with [l, r] overlapping [L, R].
int query(int p, int l, int r, int L, int R) {
    if (l <= L && R <= r) return tr[p].mn;
    const int mid = (L + R) >> 1;
    if (r <= mid) return query(tr[p].ls, l, r, L, mid);
    if (l > mid) return query(tr[p].rs, l, r, mid + 1, R);
    const int left_best = query(tr[p].ls, l, r, L, mid);
    const int right_best = query(tr[p].rs, l, r, mid + 1, R);
    return min(left_best, right_best);
}
}
// rt1[i] / rt2[i]: roots of the smt1 / smt2 versions for position i (built in solve()).
int rt1[N], rt2[N];
// Scratch output of play(): f[i] = nearest index left of i with a >= a[i] (else s - 1),
// g[i] = nearest index right of i with a > a[i] (else t + 1).
int f[N], g[N];
stack <int> stk;  // monotonic stack reused by every play() call

// Counts sub-intervals [L, R] of [s, t] whose maximum value a_i satisfies either
// a_i <= k, or [L, R] contains every value 0 .. a_i - k - 1 (i.e. mex >= a_i - k).
// Each interval is charged to the leftmost position holding its maximum, via f/g.
// When a_i > k, the shorter of [f[i]+1, i] and [i, g[i]-1] is enumerated. Each enumerated
// endpoint needs one persistent-tree query: rt1[j] gives, per value, its first position >= j;
// rt2[j] gives, per value, its last position <= j (as set up in solve()).
ll play(int s, int t) {
    stk = stack<int>();
    for (int i = s; i <= t; ++i) {
        f[i] = s - 1;
        g[i] = t + 1;
    }
    for (int i = s; i <= t; ++i) {
        while (!stk.empty() && a[stk.top()] < a[i]) {
            g[stk.top()] = i;
            stk.pop();
        }
        if (!stk.empty()) f[i] = stk.top();
        stk.push(i);
    }

    ll total = 0;
    for (int i = s; i <= t; ++i) {
        const int lo = f[i] + 1, hi = g[i] - 1;
        if (a[i] <= k) {
            // mex >= 0 >= a_i - k always holds: every interval with this maximum counts.
            total += 1ll * (i - lo + 1) * (hi - i + 1);
            continue;
        }
        const int need = a[i] - k - 1;  // every value in [0, need] must appear
        if (i - lo < hi - i) {
            // Left side is shorter: fix the left endpoint j, find the earliest valid right end.
            for (int j = lo; j <= i; ++j) {
                const int first_r = max(smt1::query(rt1[j], 0, need, 0, n), i);
                if (first_r <= hi) total += hi - first_r + 1;
            }
        } else {
            // Right side is shorter (or equal): fix the right endpoint j, find the latest valid left end.
            for (int j = i; j <= hi; ++j) {
                const int last_l = min(smt2::query(rt2[j], 0, need, 0, n), i);
                if (last_l >= lo) total += last_l - lo + 1;
            }
        }
    }
    return total;
}
void solve() {
rd(n), rd(k);
rep(i, 1, n) rd(a[i]), st1[i][0] = st2[i][0] = a[i];
rep(i, 2, n) lg2[i] = lg2[i >> 1] + 1;
rep(j, 1, 19) {
rep(i, 1, n - (1 << j) + 1) {
st1[i][j] = min(st1[i][j - 1], st1[i + (1 << j - 1)][j - 1]);
st2[i][j] = max(st2[i][j - 1], st2[i + (1 << j - 1)][j - 1]);
}
}
rt1[n + 1] = smt1::build(0, n);
rt2[0] = smt2::build(0, n);
per(i, n, 1) rt1[i] = smt1::update(rt1[i + 1], a[i], i, 0, n);
rep(i, 1, n) rt2[i] = smt2::update(rt2[i - 1], a[i], i, 0, n);
ll ans = play(1, n);
int lst = 0;
rep(i, 1, n) {
if (!a[i]) {
if (i != lst + 1) ans -= play(lst + 1, i - 1);
lst = i;
continue;
}
int l = lst + 1, r = i;
while (l < r) {
int mid = (l + r) >> 1;
if (stq2(mid, i) - stq1(mid, i) <= k) r = mid;
else l = mid + 1;
}
ans += i - l + 1;
}
if (lst != n) ans -= play(lst + 1, n);
wrt(ans);
}
// Entry point: runs solve() once. Set multi_test to true to read a test count first.
int main() {
    const bool multi_test = false;
    int T = 1;
    if (multi_test) rd(T);
    for (int tc = 0; tc < T; ++tc) solve();
    return 0;
}