题解 CF840E In a Trap(类树分块 + 阈值分治)

· · 题解

在知道这道题是阈值分治之后跑来做了做。

首先其实很佩服这道题的 trick ,由于查询保证了点对之间的祖先关系,大概想到了对每个点进行一个累似倍增的操作,然后合的时候就可以愉快合,毕竟我们的阈值分治比较重要的一个点就是乘法,如果这道题要阈值分治凑答案,那么合起来显然才能体现出阈值分治的优势。

接着我就发现这个玩意儿好像不能合啊,如果向上跳,dis 就整体加上一个值,可是这个整体加值对位运算而言显然不能维护啊,挂了挂了。

翻了题解,发现这道题还需要再在值域上套一个阈值分治,我们以 256 为我们爬树时合并答案的阈值(可以类比为对这个树分块)以及我们处理计算时的阈值。

首先我们用 f_{x,i} 表示第 x 个点此时的 dis 都需要加上 i \times 256 时的原式最大值,显然:

这里为什么可以直接 \bigoplus (i \times 256),很显然的就是我们的 dis(x,j) 不会超过 256

接着我们考虑快速计算这个玩意儿。

这时阈值分值一个值大于 256 和小于 256 的部分,也就是一个数二进制形式下 8 位后前的值:(这里的前指的是从第 0 位开始)

之后得到了这个值,我们就在查询时一步步往上合并即可,时间复杂度 O(256n + qn / 256)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int Len = 5e4 + 5,m = 256;
int fa[Len],n,q,cnt,trie[Len][2],f[Len][261],g[Len][261],dep[Len],lst[Len],val[Len];
int head[Len],tot;
bool vis[Len];
inline int read() {
    char ch = getchar();
    int x = 0, f = 1;
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while ('0' <= ch && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}
inline void write(int x) {
    if (x < 0)
        putchar('-'), x = -x;
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
struct node
{
    int next,to;
}edge[Len << 1];
void add(int from,int to)
{
    edge[++ tot].to = to;
    edge[tot].next = head[from];
    head[from] = tot; 
}
void clear(){cnt = 0 ; trie[0][0] = trie[0][1] = 0;}
inline void insert(int x)
{
    int rt = 0;
    for(int i = 7 ; i >= 0 ; i --)
    {
        int id = (x >> i) & 1;
        if(!trie[rt][id])
        {
            trie[rt][id] = ++ cnt;
            trie[cnt][0] = trie[cnt][1] = 0;
        }
        rt = trie[rt][id];
    }
}
inline int query(int x,int u)
{
    int ret = 0,rt = 0,v = 0;
    for(int i = 7 ; i >= 0 ; i --)
    {
        int id = (x >> i) & 1;
        if(trie[rt][id ^ 1]) rt = trie[rt][id ^= 1] , ret |= (1 << i);
        else rt = trie[rt][id];
        v |= id << i;
    }
    return ret << 8 | g[u][v];
}
void dfs(int x,int fft)
{
    dep[x] = dep[fft] + 1;
    fa[x] = fft;
    if(dep[x] >= m)
    {
        clear();
        int i;
        for(i = x ; dep[x] - dep[i] < m ; i = fa[i])
        {
            g[x][val[i] / 256] = max(g[x][val[i] / 256] , (dep[x] - dep[i]) ^ val[i] & 255);
            if(vis[val[i] / 256] != x) insert(val[i] / 256) , vis[val[i] / 256] = x;
        }
        lst[x] = i;//分块
        for(i = 0 ; i < m ; i ++) f[x][i] = query(i , x);
    }
    for(int e = head[x] ; e ; e = edge[e].next)
    {
        int to = edge[e].to;
        if(to == fft) continue;
        dfs(to , x);
    }
}
signed main()
{
    n = read() , q = read();
    for(int i = 1 ; i <= n ; i ++) val[i] = read();
    for(int i = 1 ; i < n ; i ++)
    {
        int x,y;x = read() , y = read();
        add(x , y) , add(y , x);
    }
    dep[1] = 1;dfs(1 , 0);
    for(int i = 1 ; i <= q ; i ++)
    {
        int x,y,dis = 0,ans = 0;x = read() , y = read();
        while(dep[y] - dep[x] >= m) 
        {
            ans = max(ans , f[y][dis]);
            y = lst[y] , dis ++;
        }
        dis <<= 8;
        while(y != fa[x])
        {
            ans = max(ans , dis ^ val[y]);
            y = fa[y] , dis ++;
        }
        write(ans) , putchar('\n');
    }
    return 0;
}