题解:P10603 BZOJ4372 烁烁的游戏

· · 题解

P10603 烁烁的游戏 题解

题意:给定一棵 n 个点带权的树(边权均为 1),初始所有点权为 0。给出 m 次操作:询问一个点的点权,或修改距离一个点不超过 d 的所有点的点权。

前置知识

建议先完成以下知识点的题目再来挑战本题。

树状数组:见 P3374 及 P3368。

点分树:见 P6329。

分析

本题与 P6329 的任务要求和解法都较为相似,不过该题为单点修改、区间查询,而本题为区间修改,单点查询。这里的区间实际上指的是树上连通块。

涉及维护树上连通块信息的问题优先考虑使用点分树解决,按照点分树的常见套路,在每个点上建立数据结构维护信息,在修改/查询时遍历节点在点分树上的祖先并修改/查询这些点上的数据结构信息。通常实际上需要建立多个数据结构,用容斥原理防止部分答案被重复计算。

对于本题,因为我们不便直接修改整个连通块内所有节点的信息,所以我们使用「标记」的思想:进行修改操作时仅在输入的节点 x 处打上一个标记,并使得在查询时能够根据所有的标记算出累计的影响。点分树上遍历祖先的操作恰好达到了这一目的。

思路

在每个点上建立两个树状数组 c_{i_0},c_{i_1} 维护节点 i 的点分子树中的节点受到修改的影响。这里我们利用树状数组维护差分数组的技巧,c_{i_0} 的前缀和 \sum_{j=0}^{k}c_{i_0,j} 维护的是节点 i 的点分子树中与节点 i 在原树上的距离不超过 k 的节点受到的累计影响,c_{i_1} 的前缀和 \sum_{j=0}^{k}c_{i_1,j} 维护的是节点 i 的点分子树中与节点 i点分父节点在原树上的距离不超过 k 的节点受到的累计影响。

进行修改操作时,首先对 c_{x_0} 进行区间修改,在 [0,k] 的值域上加上 w(由于是差分数组,实际只需进行 2 次单点修改即可)。然后遍历 x 的所有点分祖先,对祖先 uc_{u_0}[0,k-\text{dis}(x,u)] 的值域上加上 w,其中 \text{dis}(x,u) 表示 xu 在原树上的距离。注意到此时重复计算了 ux 方向上的子树内那一部分的答案,所以需要在 c_{x_1}[0,k-\text{dis}(x,u)] 值域上减去 w

查询操作则更简单,直接遍历 x 的所有点分祖先 u,累加 c_{u_0} 的前缀和 \text{ask}(c_{u_0},\text{dis}(u, x))c_{u_1} 的前缀和 \text{ask}(c_{u_1},\text{dis}(fa_u, x)) 即可。其中 fa_u 表示 u 在点分树上的父亲。

代码

#include <bits/stdc++.h>
using namespace std;

const int maxn = 1e5 + 5;

int n, m, x, k, w, ans;
char s[3];
int sum, siz[maxn], weight[maxn], root, fa[maxn];
bool vis[maxn];
vector<int> g[maxn];
vector<int> c[maxn][2];
int d[maxn], f[maxn][20];

void dfs(int u) {
    for (int v : g[u]) {
        if (d[v]) continue;
        d[v] = d[u] + 1;
        f[v][0] = u;
        for (int i = 1; i <= 18; ++i) 
            f[v][i] = f[f[v][i - 1]][i - 1];
        dfs(v);
    }
}

int lca(int x, int y) {
    if (d[x] < d[y]) swap(x, y);
    for (int i = 18; i >= 0; --i)
        if (d[f[x][i]] >= d[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = 18; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

int dist(int x, int y) {
    return d[x] + d[y] - 2 * d[lca(x, y)];
}

void add(vector<int> &c, int y, int d) {
    for (; y < c.size(); y += (y & -y)) {
        c[y] += d;
    }
}

void add(vector<int> &c, int l, int r, int d) {
    add(c, l + 1, d);
    add(c, r + 2, -d);
}

int ask(vector<int> &c, int y) {
    ++y;
    int res = 0;
    for (; y; y -= (y & -y)) 
        if (c.size() > y) res += c[y]; 
    return res;
}

void calcsize(int u, int fa) {
    siz[u] = 1;
    weight[u] = 0;
    for (int v : g[u]) {
        if (v == fa || vis[v]) continue;
        calcsize(v, u);
        siz[u] += siz[v];
        weight[u] = max(weight[u], siz[v]);
    }
    weight[u] = max(weight[u], sum - siz[u]);
    if (weight[root] > weight[u]) root = u;
}

void build(int u) {
    c[u][0].resize(siz[u] + 10);
    vis[u] = true;
    for (int v : g[u]) {
        if (vis[v]) continue;
        sum = siz[v];
        root = 0;
        calcsize(v, 0);
        calcsize(root ,0);
        fa[root] = u;
        c[root][1].resize(siz[root] + 10);
        build(root);
    }
}

int main() {
    scanf("%d %d", &n, &m);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d %d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    d[1] = 1;
    dfs(1);
    sum = n;
    weight[root = 0] = 0x3fffffff;
    calcsize(1, 0);
    calcsize(root, 0);
    build(root);
    while (m--) {
        scanf("%s %d", s, &x);
        if (s[0] == 'Q') {
            ans = ask(c[x][0], 0);
            for (int p = x; fa[p]; p = fa[p]) {
                ans += ask(c[fa[p]][0], dist(fa[p], x));
                ans += ask(c[p][1], dist(fa[p], x));
            }
            printf("%d\n", ans);
        } else {
            scanf("%d %d", &k, &w);
            add(c[x][0], 0, k, w);
            for (int p = x; fa[p]; p = fa[p]) {
                int d = dist(fa[p], x);
                if (k < d) continue; 
                add(c[fa[p]][0], 0, k - d, w);
                add(c[p][1], 0, k - d, -w);
            }
        }
    }
    return 0;
}