ABC401F题解

· · 题解

首先找到第一棵树的直径 d_1 和第二棵树的直径 d_2
不难发现,连接第一棵树的 i 结点和第二棵树的 j 结点后,f(i, j) 就等于 d_1d_2 和一条经过了边 (i,j) 的路径的长度。这条路径一定可以被分成三部分:

  1. (i, j)

则将这三部分的长度加起来就得到了 f(i, j)

我们记 i 到第一棵树中离 i 最远一点的路径长度为 a_i,记 j 到第二棵树中离 j 最远一点的路径长度为 b_j,则可写出 f(i, j) = \max(\max(d_1, d_2), a_i + b_j + 1)

计算直径的长度非常简单,先随便以一个点为起点,做一遍 bfs,找到离这个起点最远的点,这个点一定是直径的一个端点。然后再以这个端点为起点,做一遍 bfs,离这个端点最远的点一定是直径的另一个端点。

然后我们考虑如何计算出 ab。在计算直径时,我们可以发现,端点中的一个一定是离这个点的最远的一个点,我们可以在 dfs 过程中记录这个值。注意两个端点都要当一次起点。

然后我们考虑计算所有的 f(i, j) 的值。发现如果 a_i + b_j + 1 < \max(d_1, d_2),则 f(i, j) = \max(d_1, d_2)

可以将 a 按从小到大排序,将 b 从大到小排序。从 1n_2 枚举 j,然后用一个变量记录第一个 i 使得 a_i + b_j + 1 \ge \max(d_1, d_2) 的值,则此时对于所有 k(1 \le k < i),有 a_k + b_j + 1 < \max(d_1, d_2),则 f(i, j) = \max(d_1, d_2)。而所有 k(i \le k \le n_1),有 a_k + b_j + 1 \ge \max(d_1, d_2),则 f(i, j) = a_k + b_j + 1。那么 b_j 对于所有 a 的贡献就是 (i-1) \times \max(d_1, d_2) + (n_1 - i + 1) \times (b_i) + i + \displaystyle\sum_{k=i}^{n_1} a_k

显然式子的最后一项可以用后缀和优化掉。并且 i 对于 b_j 的减小是递增的,则可以用一个类似于滑动窗口的东西求出来。计算的时间复杂度是 O(n) 的。

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 2e5+5;

int n1, n2;
int d1, d2, dmax;
int dis1[MAXN], dis2[MAXN];
int d11, d12, d21, d22;
int a[MAXN], b[MAXN];
vector<int> e1[MAXN], e2[MAXN];
ll ans;
ll sa[MAXN];
//e:当前 dfs 的树。
//d:起点到 u 的距离。
//maxd:上文中的 a 或者 b。
//c:求出的端点。
void dfs(int u, int fa, vector<int> *e, int *d, int *maxd, int *c)
{
    d[u] = d[fa]+1;
    maxd[u] = max(maxd[u], d[u]);
    if(d[u] > d[*c]) *c = u;
    for(auto v : e[u])
    {
        if(v == fa) continue;
        dfs(v, u, e, d, maxd, c);
    }
}

int main()
{
    scanf("%d", &n1);
    for(int i = 1;i < n1;i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        e1[u].push_back(v);
        e1[v].push_back(u);
    }
    dis1[0] = -1;
    dfs(1, 0, e1, dis1, a, &d11);
    dfs(d11, 0, e1, dis1, a, &d12);
    d1 = dis1[d12];
    dfs(d12, 0, e1, dis1, a, &d11);//记住第二个端点也要 dfs 一便。

    scanf("%d", &n2);
    for(int i = 1;i < n2;i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        e2[u].push_back(v);
        e2[v].push_back(u);
    }
    dis2[0] = -1;
    dfs(1, 0, e2, dis2, b, &d21);
    dfs(d21, 0, e2, dis2, b, &d22);
    d2 = dis2[d22];
    dfs(d22, 0, e2, dis2, b, &d21);

    dmax = max(d1, d2);

    sort(a + 1, a + n1 + 1); a[n1+1] = 0x3f3f3f3f; a[0] = INT_MIN;//a[n1+1] 赋值是为了防止下面 cur 不断的加。
    sort(b + 1, b + n2 + 1, greater<int>() );

    for(int i = n1;i >= 1;i--) sa[i] = sa[i+1] + a[i];
    //求出 a 的后缀和 sa.
    for(int i = 1, cur = 0;i <= n2;i++)
    {
        while(a[cur] + b[i] + 1 <= dmax) cur++;
        //类似于滑动窗口的东西
        ans += (ll)(cur - 1LL) * (ll)dmax + (ll)(n1 - cur + 1LL) * (ll)(b[i] + 1LL) + sa[cur];
    }
    printf("%lld\n", ans);

    return 0;
}