P3603 雪辉 || 树分块小记

· · 题解

树分块

众所周知序列有序列分块,那么树上能否有树分块呢?

我们可以在树上选择若干个关键节点,然后可以预处理关键节点之间的信息,从而减低复杂度。

若我们选择一个阈值 S,我们希望每个点与其最近的祖先关键点之间的距离不超过 S,该怎么实现呢?

比较简单的做法是直接在树上随机选 \dfrac{n}{S},每个点与其最近的祖先关键点之间的期望距离是不超过 S 的。

还有一种方法,就是选择深度最大的非关键点,若该点 1\sim S 级祖先都不为关键点,那么标记 S 级祖先为关键点,这样就能保证每个点与其最近的祖先关键点之间的距离是不超过 S 了。

选完关键点后,预处理每对存在祖先关系的关键点之间的信息,以及每个关键点的最近关键点祖先。对于询问 (x,y),将 x 跳到第一个关键点,然后进行关键点之间的跳跃,跳到离 lca(x,y) 最近的关键点,最后慢慢跳上去,对于 y 同理。

以上就是树分块的核心操作了,下面回到本题。

P3603 雪辉

题目传送门

注意到点权不是很大,于是可以考虑使用 bitset。点权种类数即为 ans.count()\text{mex}(~ans)._Find_first()

bt_{i,j} 为关键点 i 与关键点 j 之间的点权集合,若关键点 kj 的后代,显然会有 bt_{i,k}=bt_{i,j}\operatorname{or}bt_{j,k}。由于这样的点对最多不超过 \dfrac{n^2}{S^2},预处理总复杂度为 O(\dfrac{n^2}{S}+\dfrac{n^2V}{S^2w}),其中 \dfrac{1}{w}bitset 的常数因子,V 为值域。

简单的实现方式:

void dfs(int u) {
    for (int v:E[u]) if (v != fa[u]) {
        if (id[v]) {
            for (int x = v; x != stk[top]; x = fa[x]) bt[id[stk[top]]][id[v]].set(a[x]);
            for (int i = 1; i < top; ++i) bt[id[stk[i]]][id[v]] = bt[id[stk[i]]][id[stk[top]]]|bt[id[stk[top]]][id[v]];
            up[v] = stk[top], stk[++top] = v;
        }
        dfs(v);
        if (id[v]) --top;
    }
}

查询的复杂度为 O(\dfrac{Vm}{w}+mS),总复杂度为 O(\dfrac{n^2}{S}+\dfrac{n^2V}{S^2w}+\dfrac{Vm}{w}+mS),取 S=\sqrt n,则总复杂度为 O((n+m)\sqrt n+\dfrac{V}{w}(n+m))

代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define mk make_pair
#define ll long long
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;

typedef vector <int> vi;
typedef pair <int, int> pii;

inline int rd() { int x = 0, f = 1; char c = getchar(); while (!isdigit(c)) f = c == '-' ? -1 : f, c = getchar(); while (isdigit(c)) x = (x<<3)+(x<<1)+(c^48), c = getchar(); return x*f; }
inline ll rdll() { ll x = 0, f = 1; char c = getchar(); while (!isdigit(c)) f = c == '-' ? -1 : f, c = getchar(); while (isdigit(c)) x = (x<<3)+(x<<1)+(c^48), c = getchar(); return x*f; }
template <typename T> inline void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x/10); putchar(x%10+48); }

const int N = 1e5+5, V = 3e4+5, B = 1000;
int n, m, q, lst, idx, top, a[N], d[N], fa[N], sz[N], tp[N], id[N], up[N], son[N], mxd[N], stk[N];
vi E[N]; bitset <V> ans, res, bt[105][105];

void dfs1(int u, int f) {
    d[u] = mxd[u] = d[f]+1, fa[u] = f, sz[u] = 1;
    for (int v:E[u]) if (v != f) {
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
        mxd[u] = max(mxd[u], mxd[v]);
    }
    if (mxd[u]-d[u] >= B) id[u] = ++idx, mxd[u] = d[u];
}

void dfs2(int u, int p) {
    tp[u] = p; if (son[u]) dfs2(son[u], p);
    for (int v:E[u]) if (v != fa[u] && v != son[u]) dfs2(v, v);
}

void dfs(int u) {
    for (int v:E[u]) if (v != fa[u]) {
        if (id[v]) {
            for (int x = v; x != stk[top]; x = fa[x]) bt[id[stk[top]]][id[v]].set(a[x]);
            for (int i = 1; i < top; ++i) bt[id[stk[i]]][id[v]] = bt[id[stk[i]]][id[stk[top]]]|bt[id[stk[top]]][id[v]];
            up[v] = stk[top], stk[++top] = v;
        }
        dfs(v);
        if (id[v]) --top;
    }
}

int lca(int u, int v) {
    while (tp[u] != tp[v]) {
        if (d[tp[u]] < d[tp[v]]) swap(u, v);
        u = fa[tp[u]];
    }
    return d[u] < d[v] ? u : v;
}

void solve(int u, int v) {
    res.reset(); int l = lca(u, v);
    while (u != l && !id[u]) res.set(a[u]), u = fa[u];
    while (v != l && !id[v]) res.set(a[v]), v = fa[v];
    if (u != l) {
        int pre = u;
        while (d[up[u]] >= d[l]) u = up[u];
        if (u != pre) res |= bt[id[u]][id[pre]];
        while (u != l) res.set(a[u]), u = fa[u];
    }
    if (v != l) {
        int pre = v;
        while (d[up[v]] >= d[l]) v = up[v];
        if (v != pre) res |= bt[id[v]][id[pre]];
        while (v != l) res.set(a[v]), v = fa[v];
    }
    res.set(a[l]);
}

int main() {
    n = rd(), m = rd(), q = rd();
    for (int i = 1; i <= n; ++i) a[i] = rd();
    for (int i = 1; i < n; ++i) {
        int x = rd(), y = rd();
        E[x].pb(y), E[y].pb(x);
    }
    dfs1(1, 0), dfs2(1, 1);
    if (!id[1]) id[1] = ++idx;
    stk[top = 1] = 1; dfs(1);
    while (m--) {
        int c = rd(); ans.reset();
        while (c--) {
            int x = rd()^lst, y = rd()^lst;
            solve(x, y);
            ans |= res;
        }
        int x = ans.count(), y = (~ans)._Find_first(); lst = (x+y)*q;
        write(x), space, write(y), enter;
    }
    return 0;
}