全新做法!P4103 dsu on tree 题解

· · 题解

题目链接

这题有个弱化版是 ABC359G,笔者没学过虚树,在赛时用树上启发式合并(dsu on tree)过了,后来发现这题也能这样做,于是写了本篇题解做分享。

首先,我们可以把操作离线下来,挂到每个点上,问题就变成了每个点有一些属性,总属性数量是 O(n) 的,然后求每种属性的所有点的两两距离之和、最大值和最小值。

路径统计考虑枚举 lca。枚举到 u 作为 lca 时,我们逐个遍历它的子树,先利用桶内信息计算当前子树的贡献,再把当前子树的信息添加到桶内。最后别忘了统计 u 和子树内其他点对答案的贡献。

具体地,我们用 sd_i 存储属性为 i 的点的深度之和,mxd_i 存储最大深度,mnd_i 存储最小深度,cnt_i 存储节点个数。那么假设现在枚举 u 为 lca,统计到了点 j 的属性 i,则对距离和的贡献为 sd_i+cnt_i(dep_j-2dep_u),对最大距离的贡献为 mxd_i+dep_j-2dep_u,对最小距离的贡献为 mnd_i+dep_j-2dep_u

加入一个点时也类似,根据定义更新信息即可。

暴力做是 O(n^2) 的,加上 dsu on tree 就是 O(n\log n) 的了。

具体地,因为一个子树的计算量是点数 + 每个点的属性个数总和,我们根据这个计算量,求出每个点的重儿子,再 dsu on tree 就对了。

总时间复杂度 O(q+n\log n)。跑的蛮快的。

删除一个子树的信息时,要把涉及到的信息直接初始化掉,不然最大最小值没法处理。

#include <iostream>
#include <vector>

using namespace std;

const int N = 1e6 + 5;

int n, q, sz[N], son[N], dep[N];
vector<int> e[N], qry[N];
int mn[N], mx[N], mxd[N], mnd[N], cnt[N];
long long sum[N], sd[N];

void dfs(int u, int fa) // 找出重儿子,预处理深度
{
    sz[u] = 1 + qry[u].size(), dep[u] = dep[fa] + 1;
    for (auto j : e[u])
        if (j ^ fa)
        {
            dfs(j, u), sz[u] += sz[j];
            if (sz[j] > sz[son[u]])
                son[u] = j;
        }
}

void add(int u, int fa) // 添加子树信息
{
    for (auto i : qry[u])
        mxd[i] = max(mxd[i], dep[u]), mnd[i] = min(mnd[i], dep[u]), sd[i] += dep[u], cnt[i]++;
    for (auto j : e[u])
        if (j ^ fa)
            add(j, u);
}

void del(int u, int fa) // 清除子树信息
{
    for (auto i : qry[u])
        mxd[i] = -1e9, mnd[i] = 1e9, sd[i] = cnt[i] = 0; // 直接初始化掉
    for (auto j : e[u])
        if (j ^ fa)
            del(j, u);
}

void cal(int u, int fa, int lca) // 计算子树贡献
{
    for (auto i : qry[u])
        sum[i] += sd[i] + cnt[i] * (dep[u] - 2ll * lca), mx[i] = max(mx[i], mxd[i] + dep[u] - 2 * lca),
                                                         mn[i] = min(mn[i], mnd[i] + dep[u] - 2 * lca);
    for (auto j : e[u])
        if (j ^ fa)
            cal(j, u, lca);
}

void dsu(int u, int fa) // 树上启发式合并
{
    for (auto j : e[u])
        if (j ^ fa && j ^ son[u])
            dsu(j, u), del(j, u);
    if (son[u])
        dsu(son[u], u);
    for (auto j : e[u])
        if (j ^ fa && j ^ son[u])
            cal(j, u, dep[u]), add(j, u);
    for (auto i : qry[u])
        sum[i] += sd[i] - 1ll * cnt[i] * dep[u], mx[i] = max(mx[i], mxd[i] - dep[u]),
                                                 mn[i] = min(mn[i], mnd[i] - dep[u]);
    for (auto i : qry[u])
        mxd[i] = max(mxd[i], dep[u]), mnd[i] = min(mnd[i], dep[u]), sd[i] += dep[u], cnt[i]++;
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n;
    for (int i = 1, a, b; i < n; i++)
        cin >> a >> b, e[a].push_back(b), e[b].push_back(a);
    cin >> q;
    for (int i = 1, k, p; i <= q; i++)
    {
        cin >> k, mn[i] = mnd[i] = 1e9, mx[i] = mxd[i] = -1e9; // 记得初始化
        while (k--)
            cin >> p, qry[p].push_back(i);
    }
    dfs(1, 0), dsu(1, 0);
    for (int i = 1; i <= q; i++)
        cout << sum[i] << " " << mn[i] << " " << mx[i] << "\n";
    return 0;
}