wxhtzdy ORO Tree 题解

· · 题解

wxhtzdy ORO Tree

题目大意

给定一颗 n 个点的树,每个点有权值 a_i,定义 f(x,y) 表示从 xy 的路径上的所有点的权值的按位或和,现有 q 个询问,每个询问给出 x,y,求

\max_{z\in\text{road}(x,y)}\left(\text{popcount}(f(x,z))+\text{popcount}(f(y,z))\right)

其中,\text{road}(x,y) 表示 xy 的路径上的所有点构成的集合,\text{popcount}(x) 表示 x 的二进制表示中 1 的个数。

思路分析

赛时胡出来了,但是没有时间写了,赛后花了 40 分钟写完加调完。

做法比官方题解劣,是三个 \log 的,不卡常 4500ms

我们对值域按二进制位逐位考虑:

考虑某一位时 x,y 的链是一条 01 链,我们找到从 xy 链上的第一个 1 的位置 u,将 uy 的链加 1,再找到从 yx 链上的第一个 1 的位置 v,将 vx 的链加 1,每一位都做完后查询全局最大值即可,正确性显然。

链加和查询全局最大值可以用树剖套线段树维护。

而找到第一个 1 的位置可以用树剖套二分做,比较麻烦,需要预处理一个树上前缀数组。

时间复杂度为 O(q\log^2n\log V),空间复杂度为 O(n\log V)。实现精细可以通过(比如写一个懒标记用于 O(1) 清空线段树)。

这个做法有很高的可扩展性,比如可以做到求最大的 z 的数量。

代码

(写了 4.5k,165 行,细节比较多,可以看一下注释)

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>

using namespace std;
const int N = 200200, V = 30, M = 32;
#define inf 0x3f3f3f3f
#define mid ((l + r) >> 1)
#define ls (p << 1)
#define rs (p << 1 | 1)

int T, n, q, in1, in2, cnt;
int fa[N], dep[N], dfn[N], rnk[N], top[N], son[N], siz[N];
int pre[N][M], a[N];

// pre[i][j] 表示从 1 到 i 的路径上权值的第 j 位为 1 的点的数量

vector <int> to[N];

void dfs_1(int s, int gr){
    fa[s] = gr; siz[s] = 1; son[s] = 0;
    dep[s] = dep[gr] + 1;
    for (int i = 0; i <= V; i ++)
        pre[s][i] = pre[gr][i] + (a[s] >> i & 1);// 预处理前缀数组
    for (auto v : to[s]) {
        if (v == gr) continue;
        dfs_1(v, s);
        siz[s] += siz[v];
        if (siz[son[s]] < siz[v]) son[s] = v;
    }
}

void dfs_2(int s, int tp){
    top[s] = tp; dfn[s] = ++ cnt; rnk[cnt] = s;
    if (!son[s]) return ;
    dfs_2(son[s], tp);
    for (auto v : to[s])
        if (v != fa[s] && v != son[s]) dfs_2(v, v);
}

int lca(int x, int y){// 树剖求 lca 
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] > dep[y] ? y : x;
}

struct STn{
    int tag, max, tag2;
    // tag 是区间加懒标记,max 是最值,tag2 是清空懒标记
};
struct ST{
    STn a[N << 2];
    void add_t(int p, int k){
        a[p].tag += k, a[p].max += k;
    }
    void clear_t(int p){
        a[p].tag2 = 1;
        a[p].tag = a[p].max = 0;
    }
    void push_down(int p){
        if (a[p].tag2) {// 先下放清空
            clear_t(ls); clear_t(rs);
            a[p].tag2 = 0;
        }
        if (a[p].tag) {
            add_t(ls, a[p].tag);
            add_t(rs, a[p].tag);
            a[p].tag = 0;
        }   
    }
    void add(int p, int l, int r, int x, int y, int k){// 区间加
        if (x <= l && r <= y) return add_t(p, k);
        push_down(p);
        if (x <= mid) add(ls, l, mid, x, y, k);
        if (y > mid) add(rs, mid + 1, r, x, y, k);
        a[p].max = max(a[ls].max, a[rs].max);
    }
}tree;

void add_all(int x, int y, int k){// x 到 y 链加
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        tree.add(1, 1, n, dfn[top[x]], dfn[x], k);
        x = fa[top[x]];
    }
    tree.add(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]), k);
}

int up(int x, int c){// 求从 x 到 1 的路径上第一个权值的第 c 位为 1 的点,时间复杂度是单 log 的
    while (top[x]) {
        if (pre[x][c] - pre[fa[top[x]]][c] != 0){// 链内存在 1
            int l = dfn[top[x]], r = dfn[x], ans = dfn[top[x]];// 重链上二分
            while (l <= r) {
                int Mid = (l + r) >> 1;
                if (pre[x][c] - pre[fa[rnk[Mid]]][c] != 0) l = Mid + 1, ans = Mid;
                else r = Mid - 1;
            }
            return rnk[ans];
        }
        x = fa[top[x]];
    }
    return -1;
}

int down(int x, int l, int c){// 求从 l 到 x 的路径上第一个权值的第 c 位 1 的点,l 是 x 的祖先
    vector <pair<int, int>> v;
    while (top[x] != top[l]) {// 从 x 跳到 l
        v.push_back({x, top[x]});
        x = fa[top[x]];
    }
    v.push_back({x, l});
    reverse(v.begin(), v.end());// 将跳的路径记录下来并反向
    for (auto it : v)
        if (pre[it.first][c] - pre[fa[it.second]][c] != 0) {// 存在 1
            int l = dfn[it.second], r = dfn[it.first], ans = dfn[it.second];// 二分位置
            while (l <= r) {
                int Mid = (l + r) >> 1;
                if (pre[rnk[Mid]][c] - pre[fa[it.second]][c] != 0) r = Mid - 1, ans = Mid;
                else l = Mid + 1;
            }
            return rnk[ans];
        }
    return -1;
}

int get(int x, int y, int c){// 分讨求 x 到 y 的路径上第一个权值的第 c 位为 1 的点
    int l = lca(x, y);
    if (pre[x][c] - pre[fa[l]][c] == 0) return down(y, l, c);
    else return up(x, c);
}

int main(){
    scanf("%d", &T);
    while (T --) {
        for (int i = 1; i <= n; i ++) to[i].clear();
        cnt = 0;
        scanf("%d", &n);
        for (int i = 1; i <= n; i ++) 
            scanf("%d", &a[i]);
        for (int i = 1; i < n; i ++) {
            scanf("%d %d", &in1, &in2);
            to[in1].push_back(in2);
            to[in2].push_back(in1);
        }
        dfs_1(1, 0);
        dfs_2(1, 1);
        scanf("%d", &q);
        while (q --) {
            scanf("%d %d", &in1, &in2);
            tree.clear_t(1);
            for (int i = 0; i <= V; i ++) {
                int l = lca(in1, in2);
                if (pre[in1][i] + pre[in2][i] - 2 * pre[fa[l]][i] == 0) continue;// 特判没有 1 的情况
                int u = get(in1, in2, i), v = get(in2, in1, i);
                add_all(u, in2, 1);// 链加
                add_all(in1, v, 1);
            }
            cout << tree.a[1].max << ' ';
        }
        cout << '\n';
    }
    return 0;
}