【贡献法、树上倍增、MEX、分类讨论】ABC438F

· · 题解

【贡献法、树上倍增、MEX、分类讨论】ABC438F

fix 了一个笔误: 将 f(i, j) \ge k - 1 改成了 f(i, j) \ge k + 1,用 f(i, j) \ge k 的对数减去 f(i, j) \ge k + 1 的对数才是 f(i, j) = k 的对数。

原题链接

题意

给你一棵有 N 个顶点的树 T 。顶点编号从 0N-1 , 第 i 条边 (1\le i\le N-1) 双向连接顶点 u_iv_i ,注意点从 0 开始编号。

对于一对整数 (i,j)0 \leq i,j \lt N ,定义 f(i,j) 如下:

请注意,树 T 中从顶点 i 到顶点 j 的路径包括顶点 ij

\sum_{0\le i \le j \lt N} f(i,j) 的值。

思路

赛时猛干 1 小时才过,被树论做局了。

待求式转化

这种题一般要用贡献法才能高效解决,所以很自然的考虑某个编号 k 贡献了多少次,这就需要我们统计 f(i, j) = k 的点对数量。一条路径上的点的编号的 mex 恰好等于 k,当且仅当这条路径上必须出现 0k - 1 之间的每个编号,并且 k 不能出现在这上面。这个不包含 k 很恶心,我想了一会儿感觉不太会处理。

这时可能会想到前缀和,把统计 f(i, j) = k 的点对数量,改成统计 f(i, j) \ge k 的点对数量,这样只要知道 f(i, j) \ge k + 1f(i, j) \ge k 分别的点对数,后者减去前者其实就是恰好等于 k 的了。并且,统计 f(i, j) \ge k 的点对,只需要满足 0k - 1 必须在 ij 的路径上就行,这个相对来说容易一些。

不妨记 c(k) 表示 f(i, j) \ge k 的点对数量,则可以对待求式做一个变形:

\sum_{0\le i \le j \lt N} f(i,j) = \sum_{k = 0}^{n}k\times (c(k) - c(k + 1))

我们展开写一下,会发现:

\sum_{k = 0}^{n}k\times (c(k) - c(k + 1)) = c(0) - c(1) + 2\times(c(1) - c(2)) + 3 \times (c(2) - c(3))... + n \times (c(n) - c(n + 1))\\ = c(0) + c(1) + ... + c(n) - n \times c(n + 1)

c(n + 1) = 0,所以原题其实就是要求:

\sum_{k = 0}^nc(k)

维护经过 0 到 k - 1 的路径两端

在求某个 c(k) 之前,我们首先需要看一下能覆盖 0k - 1 这些点的路径的情况:

那么,如何维护能覆盖 0k - 1 这些点的链的两端呢?假设 0k - 2 在某条链上,链两端的点是 uv,现在考虑加入 k - 1 后如何维护。这个需要分情况讨论:

计数

好,现在我们会维护链的两端 uv 了,但具体在算上面提到的点对数 cnt_1 \times cnt_2 时,应该怎么计数?我们继续进行分类讨论:

对于 u = v = 0 的初始情况,还需要特判一下这种情况的点对数。其实就是所有点对数减去不经过 0 的点对数,而不经过 0 的点对数,就是 0 的孩子的子树内部选两个点。

代码

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;

const int N = 2e5 + 10; 
const int M = 20;
const int MOD1 = 1e9 + 7;
const int MOD2 = 998244353;
const int INF = 0x3f3f3f3f;
const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const double eps = 1e-6;

int fa[N][M];
int n, lg[N], depth[N];
LL sz[N];
vector<vector<int>> e(N);

// 倍增 LCA 预处理
void pre(int u, int f, int d) {
    depth[u] = d;
    fa[u][0] = f;
    for (int i = 1; i <= lg[d]; i++) {
        if (fa[u][i - 1] != -1) {
            fa[u][i] = fa[fa[u][i - 1]][i - 1];
        } else {
            fa[u][i] = -1;
        }
    }

    sz[u] = 1;

    for (auto v : e[u]) {
        if (v != f) {
            pre(v, u, d + 1);
            sz[u] += sz[v];
        }
    }
}

int lca(int x, int y) {
    if (depth[x] < depth[y]) {
        swap(x, y);
    }

    while (depth[x] > depth[y]) {
        x = fa[x][lg[depth[x] - depth[y]]];
    }

    if (x == y) {
        return x;
    }

    for (int i = lg[depth[x]]; i >= 0; i--) {
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}

// 倍增找 x 的 k 级祖先
int get_kth_anc(int x, int k) {
    for (int i = M - 1; i >= 0; i--) {
        if ((k >> i) & 1) {
            x = fa[x][i];
        }
    }
    return x;
}

// 求 x 和 y 在树上的距离
LL get_dist(int x, int y) {
    int anc = lca(x, y);
    return depth[x] + depth[y] - 2 * depth[anc];
}

// 判断 cur 点是否在 u 到 v 的路径上
bool exist(int u, int v, int cur) {
    return get_dist(u, v) == get_dist(u, cur) + get_dist(cur, v);
}

LL C(LL n) {
    return n * (n + 1) / 2;
}

// 计算 u 这边的点乘以 v 这边的点的对数
LL calc(int u, int v) {
    int anc = lca(u, v);
    if (u == v) {
        // 初始情况,需要减去 0 的所有孩子 t 为根的子树内部两两组合的对数
        assert(u == 0);
        LL res = C(n);
        for (auto t : e[0]) {
            if (t != -1) {
                res -= C(sz[t]);
            }
        }
        return res;
    } else {
        if (anc == u) {
            // u 是 v 的祖先
            int d = depth[v] - depth[u] - 1;
            int t = get_kth_anc(v, d);
            return sz[v] * (n - sz[t]);
        } else if (anc == v) {
            int d = depth[u] - depth[v] - 1;
            int t = get_kth_anc(u, d);
            return sz[u] * (n - sz[t]);
        } else {
            return sz[u] * sz[v];
        }
    }
}

void solve() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
    }

    for (int i = 2; i <= n; i++) {
        lg[i] = lg[i / 2] + 1;
    }

    pre(0, -1, 1);

    int u = 0, v = 0;
    // 先把只有一个点的情况算了
    LL res = calc(u, v);
    for (int k = 2; k <= n; k++) {
        // t 是我们当前要新加入的编号
        int t = k - 1;
        if (exist(u, v, t)) {
            // t 本来就在 u 到 v 的路径上了,则不用更新链端点
        } else if (exist(t, u, v)) {
            // v 在 t 到 u 的路径上,说明 v 应该更新成 t
            v = t;
        } else if (exist(t, v, u)) {
            u = t;
        } else {
            break;
        }
        res += calc(u, v);
    }
    cout << res << "\n";
}

int main() {
    #ifdef LOCAL_TEST
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #endif

    ios::sync_with_stdio(false);
    cin.tie(0);

    int T = 1;
    // cin >> T;
    while (T--) {
        solve();
    }

    return 0;
}