势能线段树!

· · 题解

好好玩的一个题。

上来试图用线段树维护修改操作。如果区间的最小值超过了 v,那直接修改就可以了。否则,如果区间最小值小于 v,则向两边递归即可,这样的递归到最后一定能把某个小于 va_i 减去 v,这样的 a_i 以后每次操作都一定会参与。这里我们注意到一个小 trick,在模 2^{64} 意义下,我们的 a_i 直接减会变成一个巨大无比的数,而它同样满足“参与以后每次操作”的性质。势能上,每个 a_i 至多减法溢出一次,所以这部分的复杂度是一个 \log 的。

那么接下来我们需要解决的就只剩下最小值恰好等于 v 的情况了。按照套路,我们继续试图维护区间严格次小值(不妨假设有,若无则直接退出即可)。如果严格次小值超过两倍的 v,那么我们操作后严格次小值依然会大于最小值,直接修改即可(打一个“最小值以外的数统一减多少”的标记,注意如此一来还需要维护最小值出现次数来算贡献)。否则,严格次小值在两倍的 v 以内,注意到操作会使次小值至少减半,继续暴力递归即可。势能上,每个 a_i 至多减半 \log 次(另外其实减法溢出后不可能再次减半),所以这部分的复杂度是两个 \log 的。

代码很好写,但是作者比较菜所以调了半天。

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
constexpr int N = 5e5 + 9;
struct {
  ull mn, mncnt, sc, tsub, tnmn, sum;
} tr[1 << 20];
int n, m;
void pushup(int x) {
  int l = x << 1, r = x << 1 | 1;
  if (tr[l].mn > tr[r].mn) swap(l, r);
  auto& [xmn, xmncnt, xsc, _x, __x, xsm] = tr[x];
  auto [lmn, lmncnt, lsc, _l, __l, lsm] = tr[l];
  auto [rmn, rmncnt, rsc, _r, __r, rsm] = tr[r];
  xsm = lsm + rsm, xmncnt = lmncnt, xmn = lmn;
  if (lmn == rmn) {
    if (xmncnt += rmncnt, !(xsc = min(lsc, rsc))) xsc = lsc | rsc;
  } else if (!(xsc = min(lsc, rmn)))
    xsc = lsc | rmn;
}
void settsub(int x, ull t, int L, int R) {
  tr[x].mn -= t, tr[x].tsub += t;
  tr[x].sum -= t * (R - L + 1);
  if (ull& v = tr[x].sc) v -= t;
}
void settnmn(int x, ull t, int L, int R) {
  tr[x].tnmn += t;
  tr[x].sum -= t * (R - L + 1 - tr[x].mncnt);
  if (ull& v = tr[x].sc) v -= t;
}
void setnsub(int x, ull t, ull mn, int L, int R) {
  (tr[x].mn == mn ? settnmn : settsub)(x, t, L, R);
}
int pushdown(int x, int L, int R) {
  int mid = (L + R) >> 1;
  if (ull& t = tr[x].tsub) {
    settsub(x << 1, t, L, mid);
    settsub(x << 1 | 1, t, mid + 1, R);
    t = 0;
  }
  if (ull& t = tr[x].tnmn) {
    if (tr[x].mn > tr[x << 1].mn || tr[x].mn > tr[x << 1 | 1].mn) exit(1);
    setnsub(x << 1, t, tr[x].mn, L, mid);
    setnsub(x << 1 | 1, t, tr[x].mn, mid + 1, R);
    t = 0;
  }
  return mid;
}
ull query(int x, int l, int r, int L, int R) {
  if (l == L && r == R) return tr[x].sum;
  int mid = pushdown(x, L, R);
  if (r <= mid)
    return query(x << 1, l, r, L, mid);
  else if (l > mid)
    return query(x << 1 | 1, l, r, mid + 1, R);
  else
    return query(x << 1, l, mid, L, mid) +
           query(x << 1 | 1, mid + 1, r, mid + 1, R);
}
void update(int x, ull v, int l, int r, int L, int R) {
  if (l == L && r == R) {
    if (!tr[x].sc || tr[x].mn > v) {
      if (tr[x].mn != v) settsub(x, v, L, R);
      return;
    } else if (tr[x].mn == v && tr[x].sc > (v << 1))
      return settnmn(x, v, L, R);
  }
  int mid = pushdown(x, L, R);
  if (r <= mid)
    update(x << 1, v, l, r, L, mid);
  else if (l > mid)
    update(x << 1 | 1, v, l, r, mid + 1, R);
  else {
    update(x << 1, v, l, mid, L, mid);
    update(x << 1 | 1, v, mid + 1, r, mid + 1, R);
  }
  pushup(x);
}
void build(int x, int L, int R) {
  if (L == R) {
    ull a;
    cin >> a;
    tr[x] = {a, 1, 0, 0, 0, a};
    return;
  }
  int mid = (L + R) >> 1;
  build(x << 1, L, mid);
  build(x << 1 | 1, mid + 1, R);
  pushup(x);
}
int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n >> m, build(1, 1, n);
  ull ans = 0;
  for (int op, l, r, x; m; --m) {
    cin >> op >> l >> r, l ^= ans, r ^= ans;
    if (op == 1) {
      cin >> x, x ^= ans;
      update(1, x, l, r, 1, n);
    } else {
      ans = query(1, l, r, 1, n);
      cout << ans << '\n';
      ans &= 0xfffff;
    }
  }
  return cout << flush, 0;
}