P3313 [SDOI2014]旅行分块题解

· · 题解

提供一种简单的分块写法。

前置芝士:树剖

把树树剖一下,问题就变成了序列问题。关键是如何处理宗教。

很容易可以想到把每种宗教都维护一遍,然而使用线段树维护的空间复杂度是 \mathcal{O}(nc),会炸。

所以我们可以换一种方法,使用空间小的分块维护。将序列分成 \sqrt n 块,由于每个宗教的分块只有和与最大值不同,所以空间复杂度是 \mathcal{O}(c\sqrt n),时间复杂度是 \mathcal{O}(n\sqrt n\log n),可以通过本题。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5, P = 5e2 + 5;
struct edge {
    int v, next;
} e[N << 1];
int a[N], c[N], w1[N], w[N], pos[N], sum[N][P], maxn[N][P], L[N], R[N];
// sum[i][j] 是第j块宗教为i的和,maxn是第j块宗教为i的最大值
int head[N], fa[N], id[N], siz[N], top[N], depth[N], son[N];
int edge_cnt, n, q, u, v, x, cnt, awa, slen, sum1, sum2;
char opt[10];
inline void build() {
    slen = sqrt(n), awa = n / slen + (n % slen != 0);
    for (int id = 1; id <= awa; id ++)
        L[id] = (id - 1) * slen + 1, R[id] = id * slen;
    R[awa] = n;
    for (int id = 1; id <= awa; id ++)
        for (int j = L[id]; j <= R[id]; j ++)
            pos[j] = id, sum[c[j]][id] += a[j], maxn[c[j]][id] = max(maxn[c[j]][id], a[j]);
}
inline void clear(int id) {
    for (int i = L[id]; i <= R[id]; i ++)
        maxn[c[i]][id] = 0, sum[c[i]][id] = 0;
} // 暴力清空块的标记
inline void rebuild(int id) {
    for (int i = L[id]; i <= R[id]; i ++)
        maxn[c[i]][id] = max(maxn[c[i]][id], a[i]), sum[c[i]][id] += a[i];
} // 暴力统计块的标记
inline void updColor(int k, int x) {
    clear(pos[k]), c[k] = x, rebuild(pos[k]);
} // 暴力修改
inline void updValue(int k, int x) {
    clear(pos[k]), a[k] = x, rebuild(pos[k]);
} // 暴力修改
inline int qSum(int l, int r, int col) {
    int lp = pos[l], rp = pos[r], ans = 0;
    if (lp == rp) {
        for (int i = l; i <= r; i ++)
            if (c[i] == col) ans += a[i];
        return ans;
    } 
    for (int i = l; i <= R[lp]; i ++) if (c[i] == col) ans += a[i];
    for (int i = lp + 1; i < rp; i ++) ans += sum[col][i];
    for (int i = L[rp]; i <= r; i ++) if (c[i] == col) ans += a[i]; 
    return ans;
} // 求和
inline int qMax(int l, int r, int col) {
    int lp = pos[l], rp = pos[r], ans = 0;
    if (lp == rp) {
        for (int i = l; i <= r; i ++)
            if (c[i] == col) ans = max(ans, a[i]);
        return ans;
    }
    for (int i = l; i <= R[lp]; i ++) if (c[i] == col) ans = max(ans, a[i]);
    for (int i = lp + 1; i < rp; i ++) ans = max(ans, maxn[col][i]);
    for (int i = L[rp]; i <= r; i ++) if (c[i] == col) ans = max(ans, a[i]); 
    return ans;
} // 求最大值
void dfs1(int u, int f, int dep) {
    fa[u] = f, siz[u] = 1, depth[u] = dep;
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (v == f) continue;
        dfs1(v, u, dep + 1);
        if (siz[v] > siz[son[u]])
            son[u] = v;
        siz[u] += siz[v];
    }
}
void dfs2(int u, int topf) {
    id[u] = ++ cnt, a[cnt] = w[u], c[cnt] = w1[u], top[u] = topf;
    if (!son[u]) return;
    dfs2(son[u], topf);
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        if (v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
    }
}
inline int qRangeV(int u, int v, int k) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) swap(u, v);
        ans += qSum(id[top[u]], id[u], k), u = fa[top[u]];
    }
    if (depth[u] > depth[v]) swap(u, v);
    return ans + qSum(id[u], id[v], k);
} // 树上求和
inline int qRangeM(int u, int v, int k) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) swap(u, v);
        ans = max(ans, qMax(id[top[u]], id[u], k)), u = fa[top[u]];
    }
    if (depth[u] > depth[v]) swap(u, v);
    return max(ans, qMax(id[u], id[v], k));
} // 树上求最大值
inline void add(int u, int v) {
    e[++ edge_cnt].v = v;
    e[edge_cnt].next = head[u];
    head[u] = edge_cnt;
}
inline int read() {
    register int t = 1, a = 0;
    register char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') t = -1;
        ch = getchar();
    }
    while (ch <= '9' && ch >= '0')
        a = a * 10 + ch - '0', ch = getchar();
    return a * t;
}
int main() {
    n = read(), q = read();
    for (int i = 1; i <= n; i ++)
        w[i] = read(), w1[i] = read();
    for (int i = 1; i < n; i ++)
        u = read(), v = read(), add(u, v);
    dfs1(1, 0, 1), dfs2(1, 1), build();
    for (int i = 1; i <= q; i ++) {
        scanf("%s", opt);
        if (opt[1] == 'C') u = read(), v = read(), updColor(id[u], v);
        if (opt[1] == 'W') u = read(), v = read(), updValue(id[u], v);
        if (opt[1] == 'S') u = read(), v = read(), printf("%d\n", qRangeV(u, v, c[id[u]]));
        if (opt[1] == 'M') u = read(), v = read(), printf("%d\n", qRangeM(u, v, c[id[u]]));
    }
    return 0;
}