P6626 [省选联考 2020 B 卷] 消息传递
分治做法,但不是点分治。
考虑与 x 距离为 k 的点,贡献分为两个部分:在 x 子树内与在 x 子树外。
对于 x 子树内的部分,问题等价于查询 x 子树内深度为 dep[x] + d 的点的个数,也就是区间某数出现次数。
由于作者画功不好,只能通过定义阐述内容。
假设树为以 1 为根的有根树,将 x 到 1 的最短路径上除 x 外的所有点称为树杈。
将该最短路径上所有边去除后,将每个树杈 y 位于的联通块称为 y 的树链。
对于一个树链,将其树杈 y 及 y 所直接连接的所有边去除后,剩余的若干联通块统称为 y 的树枝,称其中直径最大的联通块的直径为该树链的最大深度。
对于每个树杈,我们称与其直接连接的树杈中深度较大的树杈为其前驱。(对于与 x 直接连接的树杈特别定义其前驱为 x)。
对于 x 子树外的部分,答案为所有树杈 y 的树链内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数的和。
容易发现,y 的树链其实就是 y 的子树去除 y 的前驱的子树,所以对于每个树杈 y,其贡献为 y 的子树内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数减去 y 的前驱的子树内深度为 d - (dep[x] - dep[y]) + dep[y] 的点的个数,也就是区间某数出现次数。
已知区间某数出现次数存在
显然,此时复杂度的全部瓶颈在于枚举树杈,最劣复杂度为
对k进行根号分治,当
当
容易发现,最大深度 [dep_x - d, dep_x - (d - T)] 中的树杈。
显然,这些树杈我们可以
对于最大深度
对于每个树杈 y 记录一个 z,表示为深度最大的满足
显然,枚举到的 z 一定覆盖了所有最大深度
这一部分复杂度也是
然后就得到了一个
于是我们考虑用
由于分界点两侧都用到了这个查询,所以没有办法平衡复杂度。
容易发现,当
作者在多次提交后发现,将 T 和 B 调到非常小时程序效率有显著提高。
这是因为,由于数据中树的形态比较集中,大部分树链都非常大,向上跳的次数远小于
经过测试,T = 30 效率最佳。
注意到第一次根号分治本质上没有多少用处,但是可以有效帮助我们处理边界情况,降低思维难度。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
#include <set>
#include <cstdlib>
#include <cmath>
#include <deque>
using namespace std;
const int N = 1e5 + 10;
#define mp make_pair
#define pii pair<int, int>
int read()
{
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return x * f;
}
int T, tot, n, m, t, cnt;
int head[N], ver[N * 2], last[N * 2], f[N][30], siz[N], pre[N], l[N], r[N], dep[N];
int a[N], ans[N], vis[N], bac[N];
vector<int> v[N];
int g[N];
void add(int x, int y)
{
ver[++tot] = y;
last[tot] = head[x];
head[x] = tot;
}
int find(int x)
{
int S = siz[x];
for (int j = 20; j >= 0; j--)
if (f[x][j] && siz[f[x][j]] - S < t)
x = f[x][j];
return x;
}
int find2(int x, int y)
{
int D = dep[x];
for (int j = 20; j >= 0; j--)
if (f[x][j] && D - dep[f[x][j]] < y)
x = f[x][j];
return x;
}
void dfs(int x, int fa)
{
siz[x] = 1;
l[x] = ++cnt;
bac[cnt] = x;
dep[x] = dep[fa] + 1;
for (int i = head[x]; i; i = last[i])
{
int y = ver[i];
if (y == fa)
continue;
f[y][0] = x;
for (int j = 1; j <= 20; j++)
f[y][j] = f[f[y][j - 1]][j - 1];
dfs(y, x);
siz[x] += siz[y];
}
r[x] = l[x] + siz[x] - 1;
}
void init()
{
tot = cnt = 0;
for (int i = 1; i <= n; i++)
head[i] = 0;
for (int i = 1; i <= n; i++)
for (int j = 0; j <= 20; j++)
f[i][j] = 0;
for (int i = 1; i <= n; i++)
a[i] = 0, ans[i] = 0, bac[i] = 0;
for (int i = 1; i <= n; i++)
v[i].clear();
}
int query(int num, int l, int r, int d, int xs)
{
int L = lower_bound(v[d].begin(), v[d].end(), l) - v[d].begin(), R = upper_bound(v[d].begin(), v[d].end(), r) - v[d].begin() - 1;
return R - L + 1;
}
int ask(int x, int y, int z, int d, int num)
{
int cha = dep[y] - dep[x], res = 0;
res += query(num, l[x], r[x], d - cha + dep[x], 1);
if (z)
res -= query(num, l[z], r[z], d - cha + dep[x], -1);
return res;
}
int main()
{
T = read();
while (T--)
{
n = read(), m = read();
init();
t = 30;
for (int i = 1; i < n; i++)
{
int u = read(), v = read();
add(u, v);
add(v, u);
}
dfs(1, 0);
for (int i = 1; i <= n; i++)
pre[i] = find(i);
for (int i = 1; i <= n; i++)
v[dep[bac[i]]].push_back(i);
for (int i = 1; i <= m; i++)
{
int x = read(), d = read(), res = 0;
res += ask(x, x, 0, d, i);
if (d <= t)
{
int cur = x;
for (int j = 1; j <= t; j++)
{
if (cur == 1)
break;
res += ask(f[cur][0], x, cur, d, i);
cur = f[cur][0];
}
}
else
{
int cur = find2(x, d - t), cntt = 0;
while (cur != 1 && dep[x] - dep[cur] <= d)
{
vis[f[cur][0]] = 1;
g[++cntt] = f[cur][0];
res += ask(f[cur][0], x, cur, d, i);
cur = f[cur][0];
}
cur = x;
while (cur != 1)
{
cur = pre[cur];
if (cur == 1)
break;
if (vis[f[cur][0]])
{
cur = f[cur][0];
continue;
}
res += ask(f[cur][0], x, cur, d, i);
cur = f[cur][0];
}
for (int i = 1; i <= cntt; i++)
vis[g[i]] = 0;
}
printf("%d\n", res);
}
}
return 0;
}