P3313 [SDOI2014] 旅行 题解

· · 题解

题意:树上问题,每点有颜色和点权。操作有:求路径上颜色为 c 的点权和或 \max,修改 x 点的颜色和点权。

首先这道题一看就需要树链剖分。先上树剖转化为区间问题。然后考虑区间问题怎么做。

考虑分块。最明显的方法应该是对序列分块。这里拿区间和举例。设 s_{c, k} 表示颜色为 c 的点在 k 块内的权值和。这样区间查询时只需要查询 s_{c, [l, r]} 即可,时间复杂度 O(\sqrt{n})。单点修改时可以直接散块暴力重构。时间复杂度 O(\sqrt{n})。带上树剖的 \log,时间复杂度 O(m \sqrt n \log n)

这样做的空间复杂度时 O(n\sqrt{n})。水平较高的同学可以直接分散层叠或者开动态数组,但是朴素算法也可以过。

大部分人止步于此。实际上,\sqrt{n} 真的是这道题的最优块长吗?让我们思考一下。设块长为 B,则单次修改的复杂度明显是 O(B) 的。单次询问分整块和散块考虑。若为整块,则时间复杂度 O(\dfrac{n}{B}),若为散块,时间复杂度 O(B)。然后直接套上树剖的复杂度 \log n,时间复杂度 O((\dfrac{n}{B} + B) \log n)。然后当 B = \sqrt{n} 时最优?不不不,当然不是这样的。

让我们考虑一下树剖的性质:每个整块在一次询问中最多被查询一次。所以整块的复杂度 O(\dfrac{n}{B}) 并不能套上树剖的 \log。然后考虑散块查询的复杂度,我们发现是有可能查询到 \log n 个散块的。故单次查询的复杂度就变成了 O(\dfrac{n}{B} + B \log n)

我们先把单点修改的复杂度抛去不看(因为复杂度远小于查询)。现在变成了一个中学数学问题:求 \dfrac{n}{B} + B \log n 的最小值。根据均值不等式可知,当 \dfrac{n}{B} = B \log n 时,即 B = \sqrt{\dfrac{n}{\log n}} 时为最优块长。此时查询的时间复杂度为 O(\sqrt{n \log n})。好耶,我们实现了将外层 \log 放到根号里!

实测这种块长的调整方式:块长为 \sqrt{n} 时运行时间 1.00 s,块长为 \sqrt{\dfrac{n}{\log n}} 时运行时间为 740 ms。不过上述过程没有考虑到树剖常数因素的影响,所以可能这个证明不是很严谨。

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>

using namespace std;

const int N = 100010, M = N << 1;
const int K = 1400, INF = 0x3f3f3f3f;

int h[N], e[M], ne[M], idx;
int _col[N], _w[N], col[N], w[N], n, m;
int sz[N], son[N], top[N], id[N], cnt, dep[N], fa[N];
int sum[N][K], maxn[N][K], len, lp[K], rp[K], blo[N];

void chkmax(int &a, int b) { a = (a > b ? a : b); }
void chkmin(int &a, int b) { a = (a < b ? a : b); }
int get(int x) { return int(x / len); }
void add(int a, int b) {
    e[ ++ idx] = b, ne[idx] = h[a], h[a] = idx;
}
void dfs1(int u, int f) {
    fa[u] = f, sz[u] = 1, dep[u] = dep[f] + 1;
    for (int i = h[u]; i; i = ne[i]) {
        int v = e[i];
        if (v == f) continue;
        dfs1(v, u); sz[u] += sz[v];
        if (sz[son[u]] < sz[v]) son[u] = v;
    }
}
void dfs2(int u, int t) {
    top[u] = t, id[u] = ++ cnt; 
    w[cnt] = _w[u], col[cnt] = _col[u];
    if (son[u]) dfs2(son[u], t);
    for (int i = h[u]; i; i = ne[i]) {
        int v = e[i];
        if (v == fa[u] or v == son[u]) continue;
        dfs2(v, v);
    }
}
int qsum(int c, int l, int r) {
    int lc = blo[l], rc = blo[r], ans = 0;
    if (lc == rc) {
        for (int i = l; i <= r; i ++ )
            if (col[i] == c) ans += w[i];
        return ans;
    }
    int i = l, j = r;
    for (; get(i) == lc; i ++ ) if (col[i] == c) ans += w[i];
    for (; get(j) == rc; j -- ) if (col[j] == c) ans += w[j];
    for (int k = get(i); k <= get(j); k ++ ) ans += sum[c][k];
    return ans;
}
int qmax(int c, int l, int r) {
    int lc = blo[l], rc = blo[r], ans = -INF;
    if (lc == rc) {
        for (int i = l; i <= r; i ++ )
            if (col[i] == c) chkmax(ans, w[i]);
        return ans;
    }
    int i = l, j = r;
    for (; get(i) == lc; i ++ ) if (col[i] == c) chkmax(ans, w[i]);
    for (; get(j) == rc; j -- ) if (col[j] == c) chkmax(ans, w[j]);
    for (int k = get(i); k <= get(j); k ++ ) chkmax(ans, maxn[c][k]);
    return ans;
}
void cgcol(int x, int c) {
    int last = col[x], k = blo[x];
    col[x] = c; maxn[last][k] = maxn[c][k] = -INF, sum[last][k] = sum[c][k] = 0;
    for (int i = lp[k]; i <= rp[k]; i ++ ) {
        if (col[i] == last) chkmax(maxn[last][k], w[i]), sum[last][k] += w[i];
        if (col[i] == c) chkmax(maxn[c][k], w[i]), sum[c][k] += w[i];
    }
}
void cgw(int x, int c) {
    w[x] = c; int k = blo[x]; maxn[col[x]][k] = -INF, sum[col[x]][k] = 0;
    for (int i = lp[k]; i <= rp[k]; i ++ ) {
        if (col[i] == col[x]) sum[col[x]][k] += w[i], chkmax(maxn[col[x]][k], w[i]);
    }
}
int qs(int c, int u, int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ans += qsum(c, id[top[u]], id[u]); u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    ans += qsum(c, id[v], id[u]);
    return ans;
}
int qm(int c, int u, int v) {
    int ans = -INF;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        chkmax(ans, qmax(c, id[top[u]], id[u])); u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    chkmax(ans, qmax(c, id[u], id[v])); return ans;
}

int main() {
    scanf("%d%d", &n, &m); len = sqrt(double(n / log2(n)));
    for (int i = 1; i <= n; i ++ )
        scanf("%d%d", &_w[i], &_col[i]);
    for (int i = 1; i < n; i ++ ) {
        int a, b; scanf("%d%d", &a, &b);
        add(a, b), add(b, a);
    }
    dfs1(1, -1), dfs2(1, 1);
    memset(lp, 0x3f, sizeof lp);
    memset(rp, -0x3f, sizeof rp);
    for (int i = 1; i <= n; i ++ ) {
        int k = get(i); blo[i] = k;
        chkmin(lp[k], i), chkmax(rp[k], i);
        sum[col[i]][k] += w[i];
        chkmax(maxn[col[i]][k], w[i]);
    }
    while (m -- ) {
        char op[3]; int x, y;
        scanf("%s%d%d", op, &x, &y);
        if (op[0] == 'C' and op[1] == 'W') cgw(id[x], y);
        if (op[0] == 'C' and op[1] == 'C') cgcol(id[x], y);
        if (op[0] == 'Q' and op[1] == 'S') printf("%d\n", qs(col[id[x]], x, y));
        if (op[0] == 'Q' and op[1] == 'M') printf("%d\n", qm(col[id[x]], x, y));
    }
    return 0;
}