CF1767F Two Subtrees 题解

· · 题解

很好的一道莫()队()题。

首先发现题目询问的是子树的信息,那么就可以把这个树按照 dfn 序弄成一个序列,一颗子树在序列中就是一个区间。那么问题就转化为两个区间中的众数。

这和普通的莫队极不相同,因为它是有四个索引。这不影响我们进行莫队,我们对这四个索引进行分块。但块长怎么设定呢?

设块长为 B(下文当 nq 同阶),前面三个索引按照块编号排序,最后一个索引按照本身的编号排序。前三个索引块编号有 \frac{n^{3}}{B^{3}} 种情况,每种情况最后一个索引会移动 n,总共 \frac{n^{4}}{B^{3}},同种情况下每个索引最多移动 B,总共 nB

所有索引加起来为 \frac{n^{4}}{B^{3}}+nB,考虑平衡 \frac{n^{4}}{B^{3}}=nB,此时 B=n^{\frac{3}{4}}

知道分块的方式了那怎么统计信息,可以想到用值域分块,即对于每个块统计最大值。当要查询时先找目标点的块,再暴力找。修改 O(1),查询 O(\sqrt n)

不过,我们求的是最大值,如果减小一个值不能 O(1) 更新,那就用回滚莫队。时间复杂度 O(n^{\frac{7}{4}})

代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define N 200010
#define B 3000
#define P 500
#define M 510
using namespace std;
int n, q, v[N];
int dfn[N], siz[N], tot = 0;
int val[N], bi[N], pi[N], ans[N];
int num[N], sma[M], tmp[M];
int nowal, nowar, nowbl, nowbr;
vector<int>e[N];
struct query
{
    int al, ar, bl, br, id;
}p[N];
void dfs(int now, int fa)
{
    dfn[now] = ++tot;
    siz[now] = 1;
    val[tot] = v[now];
    for(int i = 0; i < e[now].size(); i++)
        if(e[now][i] != fa)
        {
            dfs(e[now][i], now);
            siz[now] += siz[e[now][i]];
        }
}
int read()
{
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9')c = getchar();
    while(c >= '0' && c <= '9')x = x * 10 + c - '0', c = getchar();
    return x;
}
bool cmp(query x, query y)
{
    if(bi[x.al] != bi[y.al])return x.al < y.al;
    if(bi[x.bl] != bi[y.bl])return x.bl < y.bl;
    if(bi[x.ar] != bi[y.ar])return x.ar < y.ar;
    return x.br < y.br;
}
void insert(int x)
{
    x = val[x];
    num[x]++;
    sma[pi[x]] = max(sma[pi[x]], num[x]);
}
void erase(int x)
{
    num[val[x]]--;
}
int get()
{
    int maxn = 0, op = 1;
    for(int i = 1; i <= P; i++)
        if(sma[i] > maxn)
        {
            maxn = sma[i];
            op = i;
        }
    maxn = 0;
    int oop;
    for(int i = (op - 1) * P; i < op * P; i++)
        if(num[i] > maxn)
        {
            maxn = num[i];
            oop = i;
        }
    return oop;
}
int main()
{
    n = read();
    for(int i = 1; i <= n; i++)v[i] = read();
    for(int i = 1; i < n; i++)
    {
        int u = read(), v = read();
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1, 0);
    for(int i = 1; i <= 2e5; i++)
        bi[i] = i / B + 1;
    for(int i = 1; i <= 2e5; i++)
        pi[i] = i / P + 1;
    q = read();
    for(int i = 1; i <= q; i++)
    {
        int u = read(), v = read();
        p[i] = (query){dfn[u], dfn[u] + siz[u] - 1, dfn[v], dfn[v] + siz[v] - 1, i};
    }
    sort(p + 1, p + 1 + q, cmp);
    for(int i = 1; i <= q; i++)
    {
        if(bi[p[i].al] != bi[p[i - 1].al] || bi[p[i].bl] != bi[p[i - 1].bl] || bi[p[i].ar] != bi[p[i - 1].ar])
        {
            memset(num, 0, sizeof num);
            memset(sma, 0, sizeof sma);
            nowal = bi[p[i].al] * B;
            nowar = (bi[p[i].ar] - 1) * B;
            nowbl = bi[p[i].bl] * B;
            nowbr = bi[p[i].bl] * B - 1;
            if(nowal <= nowar)
                for(int j = nowal; j <= nowar; j++)insert(j);
        }
        if(bi[p[i].br] != bi[p[i].bl])
            while(nowbr < p[i].br)insert(++nowbr);
        memcpy(tmp, sma, sizeof tmp);
        if(nowal <= nowar)
        {
            for(int j = p[i].al; j < nowal; j++)insert(j);
            for(int j = nowar + 1; j <= p[i].ar; j++)insert(j);
        }
        else
            for(int j = p[i].al; j <= p[i].ar; j++)insert(j);
        if(bi[p[i].bl] == bi[p[i].br])
            for(int j = p[i].bl; j <= p[i].br; j++)insert(j);
        else
            for(int j = p[i].bl; j < nowbl; j++)insert(j);
        ans[p[i].id] = get();
        if(nowal <= nowar)
        {
            for(int j = p[i].al; j < nowal; j++)erase(j);
            for(int j = nowar + 1; j <= p[i].ar; j++)erase(j);
        }
        else
            for(int j = p[i].al; j <= p[i].ar; j++)erase(j);
        if(bi[p[i].bl] == bi[p[i].br])
            for(int j = p[i].bl; j <= p[i].br; j++)erase(j);
        else
            for(int j = p[i].bl; j < nowbl; j++)erase(j);
        memcpy(sma, tmp, sizeof tmp);
    }
    for(int i = 1; i <= q; i++)printf("%d\n", ans[i]);
    return 0;
}