CF1767F Two Subtrees 题解
很好的一道莫(卡)队(常)题。
首先发现题目询问的是子树的信息,那么就可以把这个树按照 dfn 序弄成一个序列,一颗子树在序列中就是一个区间。那么问题就转化为两个区间中的众数。
这和普通的莫队极不相同,因为它是有四个索引。这不影响我们进行莫队,我们对这四个索引进行分块。但块长怎么设定呢?
设块长为
所有索引加起来为
知道分块的方式了那怎么统计信息,可以想到用值域分块,即对于每个块统计最大值。当要查询时先找目标点的块,再暴力找。修改
不过,我们求的是最大值,如果减小一个值不能
代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define N 200010
#define B 3000
#define P 500
#define M 510
using namespace std;
int n, q, v[N];
int dfn[N], siz[N], tot = 0;
int val[N], bi[N], pi[N], ans[N];
int num[N], sma[M], tmp[M];
int nowal, nowar, nowbl, nowbr;
vector<int>e[N];
struct query
{
int al, ar, bl, br, id;
}p[N];
void dfs(int now, int fa)
{
dfn[now] = ++tot;
siz[now] = 1;
val[tot] = v[now];
for(int i = 0; i < e[now].size(); i++)
if(e[now][i] != fa)
{
dfs(e[now][i], now);
siz[now] += siz[e[now][i]];
}
}
int read()
{
int x = 0;
char c = getchar();
while(c < '0' || c > '9')c = getchar();
while(c >= '0' && c <= '9')x = x * 10 + c - '0', c = getchar();
return x;
}
bool cmp(query x, query y)
{
if(bi[x.al] != bi[y.al])return x.al < y.al;
if(bi[x.bl] != bi[y.bl])return x.bl < y.bl;
if(bi[x.ar] != bi[y.ar])return x.ar < y.ar;
return x.br < y.br;
}
void insert(int x)
{
x = val[x];
num[x]++;
sma[pi[x]] = max(sma[pi[x]], num[x]);
}
void erase(int x)
{
num[val[x]]--;
}
int get()
{
int maxn = 0, op = 1;
for(int i = 1; i <= P; i++)
if(sma[i] > maxn)
{
maxn = sma[i];
op = i;
}
maxn = 0;
int oop;
for(int i = (op - 1) * P; i < op * P; i++)
if(num[i] > maxn)
{
maxn = num[i];
oop = i;
}
return oop;
}
int main()
{
n = read();
for(int i = 1; i <= n; i++)v[i] = read();
for(int i = 1; i < n; i++)
{
int u = read(), v = read();
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
for(int i = 1; i <= 2e5; i++)
bi[i] = i / B + 1;
for(int i = 1; i <= 2e5; i++)
pi[i] = i / P + 1;
q = read();
for(int i = 1; i <= q; i++)
{
int u = read(), v = read();
p[i] = (query){dfn[u], dfn[u] + siz[u] - 1, dfn[v], dfn[v] + siz[v] - 1, i};
}
sort(p + 1, p + 1 + q, cmp);
for(int i = 1; i <= q; i++)
{
if(bi[p[i].al] != bi[p[i - 1].al] || bi[p[i].bl] != bi[p[i - 1].bl] || bi[p[i].ar] != bi[p[i - 1].ar])
{
memset(num, 0, sizeof num);
memset(sma, 0, sizeof sma);
nowal = bi[p[i].al] * B;
nowar = (bi[p[i].ar] - 1) * B;
nowbl = bi[p[i].bl] * B;
nowbr = bi[p[i].bl] * B - 1;
if(nowal <= nowar)
for(int j = nowal; j <= nowar; j++)insert(j);
}
if(bi[p[i].br] != bi[p[i].bl])
while(nowbr < p[i].br)insert(++nowbr);
memcpy(tmp, sma, sizeof tmp);
if(nowal <= nowar)
{
for(int j = p[i].al; j < nowal; j++)insert(j);
for(int j = nowar + 1; j <= p[i].ar; j++)insert(j);
}
else
for(int j = p[i].al; j <= p[i].ar; j++)insert(j);
if(bi[p[i].bl] == bi[p[i].br])
for(int j = p[i].bl; j <= p[i].br; j++)insert(j);
else
for(int j = p[i].bl; j < nowbl; j++)insert(j);
ans[p[i].id] = get();
if(nowal <= nowar)
{
for(int j = p[i].al; j < nowal; j++)erase(j);
for(int j = nowar + 1; j <= p[i].ar; j++)erase(j);
}
else
for(int j = p[i].al; j <= p[i].ar; j++)erase(j);
if(bi[p[i].bl] == bi[p[i].br])
for(int j = p[i].bl; j <= p[i].br; j++)erase(j);
else
for(int j = p[i].bl; j < nowbl; j++)erase(j);
memcpy(sma, tmp, sizeof tmp);
}
for(int i = 1; i <= q; i++)printf("%d\n", ans[i]);
return 0;
}