题解:P13825 线段树 1.5

· · 题解

Analysis

首先问题可以变成一个初始全 0 的序列按要求做操作,把每次查询的答案加上 [l,r] 的原区间和,即 \frac{(r-l+1)(l+r)}{2}

然后考虑这个对初始全 0 序列做区间加和区间查询的操作,如果 n 很小,那就可以直接用普通的线段树(P3372)来解决。

假设我们能把这个 n 很大的线段树建出来,注意到初始时所有节点的值(包括区间和以及 lazy tag)都是 0。一次修改和查询操作只会访问到树上 O(\log n) 个节点。显然只有被访问过的节点才有可能被修改值。所以 m 次操作结束后,这棵树最多只有 O(m \log n) 个节点的值是非 0 的。

我们只维护非 0 的节点。对于一个节点,如果它的某个孩子指针为空,则表示该孩子和子树的值均为 0。在最开始,只创建根节点。注意到每次线段树向下递归时总会先进行一次 pushdown,所以只需要在 pushdown 时提前把空的左右孩子创建出来,就可以保证访问到的节点总是非空的。并且因为一共只进行了 O(m \log n) 次 pushdown,所以创建的总节点数也是 O(m \log n)

注意到这样的写法仍然创建了一些仅用于查询的(本身没被修改过,即值为 0)的节点。更精细的写法可以只创建真正被修改过的节点,在查询时手动跳过值为 0 的递归。但是在大部分情况下(查询数和修改数同阶),二者创建的节点数差别并不大,本文的写法已经足够了。

注意要开 unsiged long long。

Code

#include <bits/stdc++.h>

using ll = unsigned long long;

struct Node {
  Node *ls, *rs;
  int l, r;
  ll val, tag;

  Node(int L, int R) : l(L), r(R), val(0), tag(0) {
    ls = rs = nullptr;
  }

  void pushup() { val = ls->val + rs->val; }
  bool inRange(int L, int R) { return L <= l && r <= R; }
  bool outRange(int L, int R) { return r < L || R < l; }
  inline void makeTag(ll x) {
    val += (r - l + 1) * x;
    tag += x;
  }
  inline void pushdown() {
    if (!ls) {
      int mid = (l + r) >> 1;
      ls = new Node(l, mid);
      rs = new Node(mid + 1, r);
    }
    if (tag) {
      ls->makeTag(tag);
      rs->makeTag(tag);
      tag = 0;
    }
  }

  void upd(int L, int R, ll k) {
    if (inRange(L, R)) {
      makeTag(k);
    } else if (!outRange(L, R)) {
      pushdown();
      ls->upd(L, R, k);
      rs->upd(L, R, k);
      pushup();
    }
  }

  ll qry(int L, int R) {
    if (inRange(L, R)) {
      return val;
    } else if (outRange(L, R)) {
      return 0;
    } else {
      pushdown();
      return ls->qry(L, R) + rs->qry(L, R);
    }
  }
};

int main() {
  int n, q;
  std::cin >> n >> q;
  auto rot = new Node(1, n);
  for (int op, l, r; q; --q) {
    std::cin >> op >> l >> r;
    if (op == 1) {
      ll k;
      std::cin >> k;
      rot->upd(l, r, k);
    } else {
      std::println("{}", rot->qry(l, r) + [](ll L, ll R) -> ll {
        return (R - L + 1) * (R + L) / 2;
      }(l, r));
    }
  }
}