题解:P9775 [HUSTFC 2023] 广义线段树

· · 题解

广义线段树,从名字就能看出它进行单点修改时的流程。

从根节点出发,如果当前节点对应的区间包含要修改的点,则修改当前节点,并向下递归缩小区间。

所以当我们处理到 b_i 时,我们就只要修改节点 i+n-1 到根节点 1 的路径上的点并更新答案就行了。

但是从题目中的图就可以看出,这棵广义线段树不一定平衡,暴力往上跳的单次修改的时间复杂度可能为 O(n),肯定超时,所以我们用树链剖分来修改和查询,这样可以将单次操作的时间降为 O(\log^2 n)

AC Code

#include <iostream>
using namespace std;
using ll = long long;
const int N = 5e5 + 5;
const int MOD = 998244353;
int n;
int lpi[N], rpi[N];
int a[N], b[N];
int fa[N << 1], dep[N << 1], siz[N << 1], son[N << 1];
ll val[N << 1];
void dfs1(int u) {
  dep[u] = dep[fa[u]] + 1;
  siz[u] = 1;
  if (u >= n)
    return;
  for (int v : {lpi[u], rpi[u]}) {
    dfs1(v);
    siz[u] += siz[v];
    if (siz[v] > siz[son[u]])
      son[u] = v;
  }
}
int dfn[N << 1], top[N << 1], tot = 0;
ll ans = 0;
ll dfs2(int u, int f) {
  dfn[u] = ++tot;
  top[u] = f;
  if (u >= n) {
    (ans += (val[tot] = a[u - n + 1])) %= MOD;
    return val[tot];
  }
  val[dfn[u]] = dfs2(son[u], f);
  if (son[u] == lpi[u])
    (val[dfn[u]] *= dfs2(rpi[u], rpi[u])) %= MOD;
  else
    (val[dfn[u]] *= dfs2(lpi[u], lpi[u])) %= MOD;
  (ans += val[dfn[u]]) %= MOD;
  return val[dfn[u]];
}
#define lp (p << 1)
#define rp (p << 1 | 1)
struct SegTree {
  struct Node {
    int l, r;
    ll sum, tag;
  } tr[N << 3];
  void pushup(int p) { tr[p].sum = (tr[lp].sum + tr[rp].sum) % MOD; }
  void pushdown(int p) {
    if (tr[p].tag == 1)
      return;
    ll tag = tr[p].tag;
    (tr[lp].sum *= tag) %= MOD;
    (tr[rp].sum *= tag) %= MOD;
    (tr[lp].tag *= tag) %= MOD;
    (tr[rp].tag *= tag) %= MOD;
    tr[p].tag = 1;
  }
  void build(int p, int l, int r) {
    tr[p] = {l, r, 1, 1};
    if (l == r) {
      tr[p].sum = val[l];
      return;
    }
    int mid = (l + r) >> 1;
    build(lp, l, mid), build(rp, mid + 1, r);
    pushup(p);
  }
  void update(int p, int l, int r, ll k) {
    if (l <= tr[p].l && tr[p].r <= r) {
      (tr[p].sum *= k) %= MOD;
      (tr[p].tag *= k) %= MOD;
      return;
    }
    pushdown(p);
    if (l <= tr[lp].r)
      update(lp, l, r, k);
    if (tr[rp].l <= r)
      update(rp, l, r, k);
    pushup(p);
  }
  ll query(int p, int l, int r) {
    if (l <= tr[p].l && tr[p].r <= r)
      return tr[p].sum;
    pushdown(p);
    ll res = 0;
    if (l <= tr[lp].r)
      (res += query(lp, l, r)) %= MOD;
    if (tr[rp].l <= r)
      (res += query(rp, l, r)) %= MOD;
    return res;
  }
} seg;
#undef lp
#undef rp
ll queryPath(int p) {
  ll res = 0;
  p += n - 1;
  while (p) {
    (res += seg.query(1, dfn[top[p]], dfn[p])) %= MOD;
    p = fa[top[p]];
  }
  return res;
}
void updatePath(int p, ll k) {
  p += n - 1;
  while (p) {
    seg.update(1, dfn[top[p]], dfn[p], k);
    p = fa[top[p]];
  }
}
void changePoint(int p) {
  (ans += (b[p] - 1) * queryPath(p)) %= MOD;
  cout << ans << ' ';
  updatePath(p, b[p]);
}
int main() {
  ios::sync_with_stdio(0);
  cin.tie(0), cout.tie(0);
  cin >> n;
  for (int i = 1; i <= n; i++)
    cin >> a[i];
  for (int i = 1; i <= n; i++)
    cin >> b[i];
  for (int i = 1; i < n; i++) {
    cin >> lpi[i] >> rpi[i];
    fa[lpi[i]] = fa[rpi[i]] = i;
  }
  dfs1(1), dfs2(1, 1);
  seg.build(1, 1, 2 * n - 1);
  for (int i = 1; i <= n; i++)
    changePoint(i);
  return 0;
}