【贡献法、树上倍增、MEX、分类讨论】ABC438F
【贡献法、树上倍增、MEX、分类讨论】ABC438F
fix 了一个笔误: 将
f(i, j) \ge k - 1 改成了f(i, j) \ge k + 1 ,用f(i, j) \ge k 的对数减去f(i, j) \ge k + 1 的对数才是f(i, j) = k 的对数。
原题链接
题意
给你一棵有
对于一对整数
- 在树
T 中,从顶点i 到顶点j 的路径上未包含的顶点中编号最小的顶点的顶点编号。- 在这里,如果从顶点
i 到顶点j 的路径包含了从顶点0 到顶点N-1 的所有顶点,则设为f(i,j)=N 。
- 在这里,如果从顶点
请注意,树
求
思路
赛时猛干 1 小时才过,被树论做局了。
待求式转化
这种题一般要用贡献法才能高效解决,所以很自然的考虑某个编号
这时可能会想到前缀和,把统计
不妨记
我们展开写一下,会发现:
而
维护经过 0 到 k - 1 的路径两端
在求某个
- 如果
0 到k - 1 这些点恰好同时在树的某条链上,并且这个链的两端u 和v 也都在[0, k - 1] 范围内,则这时说明存在f(i, j) \ge k 的点对,点对个数就是u 那端的部分的大小,乘以v 那端的部分的大小,即图中的cnt_1 \times cnt_2 。 - 如果不存在一条链,使得
0 到k - 1 都在这条链上,说明必须是一个分叉的结构才能覆盖这些点,也就说明了不存在f(i, j) \ge k 的点对了,自然也不用往后看更大的k 了。
那么,如何维护能覆盖
- 如果
k - 1 在u 到v 的路径上,则链不变。这个事情可以使用距离去判断,只要dist(u, k - 1) + dist(k - 1, v) = dist(u, v) ,则说明k - 1 在u 到v 的路径上。 - 如果
k - 1 不在u 到v 的路径上,则要考虑k - 1 是在u 那边(黄圈里),还是在v 那边(红圈里),或者是在其他地方(中间蓝圈里):- 如果
k - 1 在u 那边(黄圈里),那么显然存在k - 1 到u 这条路径,再拼上从u 到v 的路径,这样得到一个新的链,其两端是k - 1 和v ,这条链上包含0 到k - 1 的所有点,我们应该把u 更新成k - 1 。换句话说,此时u 在k - 1 到v 这条链上,我们可以用dist(k - 1, u) + dist(u, v) = dist(k - 1, v) 去判断。 - 如果
k - 1 在v 那边(红圈里),和上一种情况是类似的,不再赘述。 - 否则,就是在蓝圈里,说明一条链覆盖不了
0 到k - 1 这些点了。
- 如果
计数
好,现在我们会维护链的两端
- 如果
u 是v 的祖先,则如下图所示,我们需要求u 的能走到v 的那个直接儿子t ,那么cnt_2 = size_v ,cnt_1 = n - size_t ,其中size_v 指的是以0 为树根时,v 子树的大小。t 怎么找呢?其实t 就是v 的depth_v - depth_u - 1 级祖先,倍增找就行了 - 如果
v 是u 的祖先,和上一种情况类似,不再赘述。 - 如果
u 和v 的LCA 是anc ,则如图所示,其实点对数就是size_u \times size_v 了。
对于
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
const int N = 2e5 + 10;
const int M = 20;
const int MOD1 = 1e9 + 7;
const int MOD2 = 998244353;
const int INF = 0x3f3f3f3f;
const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-6;
int fa[N][M];
int n, lg[N], depth[N];
LL sz[N];
vector<vector<int>> e(N);
// 倍增 LCA 预处理
void pre(int u, int f, int d) {
depth[u] = d;
fa[u][0] = f;
for (int i = 1; i <= lg[d]; i++) {
if (fa[u][i - 1] != -1) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
} else {
fa[u][i] = -1;
}
}
sz[u] = 1;
for (auto v : e[u]) {
if (v != f) {
pre(v, u, d + 1);
sz[u] += sz[v];
}
}
}
int lca(int x, int y) {
if (depth[x] < depth[y]) {
swap(x, y);
}
while (depth[x] > depth[y]) {
x = fa[x][lg[depth[x] - depth[y]]];
}
if (x == y) {
return x;
}
for (int i = lg[depth[x]]; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
// 倍增找 x 的 k 级祖先
int get_kth_anc(int x, int k) {
for (int i = M - 1; i >= 0; i--) {
if ((k >> i) & 1) {
x = fa[x][i];
}
}
return x;
}
// 求 x 和 y 在树上的距离
LL get_dist(int x, int y) {
int anc = lca(x, y);
return depth[x] + depth[y] - 2 * depth[anc];
}
// 判断 cur 点是否在 u 到 v 的路径上
bool exist(int u, int v, int cur) {
return get_dist(u, v) == get_dist(u, cur) + get_dist(cur, v);
}
LL C(LL n) {
return n * (n + 1) / 2;
}
// 计算 u 这边的点乘以 v 这边的点的对数
LL calc(int u, int v) {
int anc = lca(u, v);
if (u == v) {
// 初始情况,需要减去 0 的所有孩子 t 为根的子树内部两两组合的对数
assert(u == 0);
LL res = C(n);
for (auto t : e[0]) {
if (t != -1) {
res -= C(sz[t]);
}
}
return res;
} else {
if (anc == u) {
// u 是 v 的祖先
int d = depth[v] - depth[u] - 1;
int t = get_kth_anc(v, d);
return sz[v] * (n - sz[t]);
} else if (anc == v) {
int d = depth[u] - depth[v] - 1;
int t = get_kth_anc(u, d);
return sz[u] * (n - sz[t]);
} else {
return sz[u] * sz[v];
}
}
}
void solve() {
cin >> n;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 2; i <= n; i++) {
lg[i] = lg[i / 2] + 1;
}
pre(0, -1, 1);
int u = 0, v = 0;
// 先把只有一个点的情况算了
LL res = calc(u, v);
for (int k = 2; k <= n; k++) {
// t 是我们当前要新加入的编号
int t = k - 1;
if (exist(u, v, t)) {
// t 本来就在 u 到 v 的路径上了,则不用更新链端点
} else if (exist(t, u, v)) {
// v 在 t 到 u 的路径上,说明 v 应该更新成 t
v = t;
} else if (exist(t, v, u)) {
u = t;
} else {
break;
}
res += calc(u, v);
}
cout << res << "\n";
}
int main() {
#ifdef LOCAL_TEST
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
return 0;
}