P6626 [省选联考 2020 B 卷] 消息传递

· · 题解

分治做法,但不是点分治。

考虑与 x 距离为 k 的点,贡献分为两个部分:在 x 子树内与在 x 子树外。

对于 x 子树内的部分,问题等价于查询 x 子树内深度为 dep[x] + d 的点的个数,也就是区间某数出现次数。

由于作者画功不好,只能通过定义阐述内容。

假设树为以 1 为根的有根树,将 x 到 1 的最短路径上除 x 外的所有点称为树杈。

将该最短路径上所有边去除后,将每个树杈 y 位于的联通块称为 y 的树链。

对于一个树链,将其树杈 y 及 y 所直接连接的所有边去除后,剩余的若干联通块统称为 y 的树枝,称其中直径最大的联通块的直径为该树链的最大深度。

对于每个树杈,我们称与其直接连接的树杈中深度较大的树杈为其前驱。(对于与 x 直接连接的树杈特别定义其前驱为 x)。

对于 x 子树外的部分,答案为所有树杈 y 的树链内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数的和。

容易发现,y 的树链其实就是 y 的子树去除 y 的前驱的子树,所以对于每个树杈 y,其贡献为 y 的子树内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数减去 y 的前驱的子树内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数,也就是区间某数出现次数。

已知区间某数出现次数存在 O(n\sqrt{n}) - O(1) 的可持久化块状树做法,即将线段树分为 \sqrt{n} 叉后可持久化。

显然,此时复杂度的全部瓶颈在于枚举树杈,最劣复杂度为 O(n\sqrt{n} + nq)

对k进行根号分治,当 k \le B 时,有贡献的树杈个数只有最多B个,直接暴力枚举即可,复杂度为 O(qB)

k > B 时,继续进行分治:树链的最大深度 \le T 的树杈和 > T 的树杈。

容易发现,最大深度 \le T 的树杈中对答案有贡献的树杈只有 \le T 个,也就是深度在 [dep_x - d, dep_x - (d - T)] 中的树杈。

显然,这些树杈我们可以 O(T) 枚举,这一部分的复杂度为 O(qT)

对于最大深度 > T 的树杈,个数只有 \le \dfrac{n}{T} 个。

对于每个树杈 y 记录一个 z,表示为深度最大的满足 siz_z - siz_y > T 的树杈。查询时只需要从 x 出发,不停前往 z 并记录贡献即可。

显然,枚举到的 z 一定覆盖了所有最大深度 > T 的树杈。

这一部分复杂度也是 O(\dfrac{qn}T)

然后就得到了一个 O(n\sqrt{n} + qB + qT + \dfrac{qn}{T}) 时间复杂度和 O(n\sqrt{n}) 空间复杂度的算法(可持久化块状树空间复杂度为 O(n \sqrt{n}).然而,这道题对空间比较严苛,实现不好的 O(n\sqrt{n}) 空间复杂度不能通过。

于是我们考虑用 O(n) - O(\log n) 且线性空间的 vector 代替可持久化块状树。这是在不采用 O(n\sqrt{n}) 空间复杂度时查询较快的一种方法,因为二分的常数通常非常小。

由于分界点两侧都用到了这个查询,所以没有办法平衡复杂度。

容易发现,当 T = B = \sqrt{n}O(n\sqrt{n} \log n) 的复杂度很难通过 n = 10^5

作者在多次提交后发现,将 T 和 B 调到非常小时程序效率有显著提高。

这是因为,由于数据中树的形态比较集中,大部分树链都非常大,向上跳的次数远小于 \dfrac{n}{T}

经过测试,T = 30 效率最佳。

注意到第一次根号分治本质上没有多少用处,但是可以有效帮助我们处理边界情况,降低思维难度。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
#include <set>
#include <cstdlib>
#include <cmath>
#include <deque>
using namespace std;
const int N = 1e5 + 10;
#define mp make_pair
#define pii pair<int, int>
int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
            f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
int T, tot, n, m, t, cnt;
int head[N], ver[N * 2], last[N * 2], f[N][30], siz[N], pre[N], l[N], r[N], dep[N];
int a[N], ans[N], vis[N], bac[N];
vector<int> v[N];
int g[N];
void add(int x, int y)
{
    ver[++tot] = y;
    last[tot] = head[x];
    head[x] = tot;
}
int find(int x)
{
    int S = siz[x];
    for (int j = 20; j >= 0; j--)
        if (f[x][j] && siz[f[x][j]] - S < t)
            x = f[x][j];
    return x;
}
int find2(int x, int y)
{
    int D = dep[x];
    for (int j = 20; j >= 0; j--)
        if (f[x][j] && D - dep[f[x][j]] < y)
            x = f[x][j];
    return x;
}
void dfs(int x, int fa)
{
    siz[x] = 1;
    l[x] = ++cnt;
    bac[cnt] = x;
    dep[x] = dep[fa] + 1;
    for (int i = head[x]; i; i = last[i])
    {
        int y = ver[i];
        if (y == fa)
            continue;
        f[y][0] = x;
        for (int j = 1; j <= 20; j++)
            f[y][j] = f[f[y][j - 1]][j - 1];
        dfs(y, x);
        siz[x] += siz[y];
    }
    r[x] = l[x] + siz[x] - 1;
}
void init()
{
    tot = cnt = 0;
    for (int i = 1; i <= n; i++)
        head[i] = 0;
    for (int i = 1; i <= n; i++)
        for (int j = 0; j <= 20; j++)
            f[i][j] = 0;
    for (int i = 1; i <= n; i++)
        a[i] = 0, ans[i] = 0, bac[i] = 0;
    for (int i = 1; i <= n; i++)
        v[i].clear();
}
int query(int num, int l, int r, int d, int xs)
{
    int L = lower_bound(v[d].begin(), v[d].end(), l) - v[d].begin(), R = upper_bound(v[d].begin(), v[d].end(), r) - v[d].begin() - 1;
    return R - L + 1;
}
int ask(int x, int y, int z, int d, int num)
{
    int cha = dep[y] - dep[x], res = 0;
    res += query(num, l[x], r[x], d - cha + dep[x], 1);
    if (z)
        res -= query(num, l[z], r[z], d - cha + dep[x], -1);
    return res;
}
int main()
{
    T = read();
    while (T--)
    {
        n = read(), m = read();
        init();
        t = 30;
        for (int i = 1; i < n; i++)
        {
            int u = read(), v = read();
            add(u, v);
            add(v, u);
        }
        dfs(1, 0);
        for (int i = 1; i <= n; i++)
            pre[i] = find(i);
        for (int i = 1; i <= n; i++)
            v[dep[bac[i]]].push_back(i);
        for (int i = 1; i <= m; i++)
        {
            int x = read(), d = read(), res = 0;
            res += ask(x, x, 0, d, i);
            if (d <= t)
            {
                int cur = x;
                for (int j = 1; j <= t; j++)
                {
                    if (cur == 1)
                        break;
                    res += ask(f[cur][0], x, cur, d, i);
                    cur = f[cur][0];
                }
            }
            else
            {
                int cur = find2(x, d - t), cntt = 0;
                while (cur != 1 && dep[x] - dep[cur] <= d)
                {
                    vis[f[cur][0]] = 1;
                    g[++cntt] = f[cur][0];
                    res += ask(f[cur][0], x, cur, d, i);
                    cur = f[cur][0];
                }
                cur = x;
                while (cur != 1)
                {
                    cur = pre[cur];
                    if (cur == 1)
                        break;
                    if (vis[f[cur][0]])
                    {
                        cur = f[cur][0];
                        continue;
                    }
                    res += ask(f[cur][0], x, cur, d, i);
                    cur = f[cur][0];
                }
                for (int i = 1; i <= cntt; i++)
                    vis[g[i]] = 0;
            }
            printf("%d\n", res);
        }
    }
    return 0;
}