题解 CF2063E Triangle Tree

· · 题解

年级里面有大佬用神奇启发式合并写的,还好我不会这么复杂的算法。

题意

给定一棵 n 个点的树。

设函数 $f(u,v)$ 表示: - 当 $u$ 不为 $v$ 的祖先,且 $v$ 不为 $u$ 的祖先时,存在多少个整数 $x$ 使得边长为 $\operatorname{dist}(u,\operatorname{lca}(u,v))$ 、 $\operatorname{dist}(v,\operatorname{lca}(u,v))$、$x$ 能成为一个三角形。 - 否则函数值为 $0$。 最后需要求出: $$ \sum_{i = 1}^{n-1} \sum_{j = i+1}^n f(i,j) $$ ## 分析 根据三角形边长的限制,可以得到: $$ \vert \operatorname{dist}(u,\operatorname{lca}(u,v)) - \operatorname{dist}(v,\operatorname{lca}(u,v)) \vert < x < \operatorname{dist}(u,\operatorname{lca}(u,v)) + \operatorname{dist}(v,\operatorname{lca}(u,v)) $$ 简单来说两边之和大于第三边,两边之差小于第三边。 把距离函数拆成深度,设 $d_u$ 为 $u$ 的深度,$lca = \operatorname{lca}(u,v)$,那么以上式子可以表示为: $$ \vert d_u - d_v \vert < x < d_u + d_v - 2d_{lca} $$ 然后就可以表示出 $f(u,v)$: $$ f(u,v) = (d_u + d_v - 2d_{lca}) - \vert d_u -d_v \vert - 1 $$ 后面绝对值这一部分是 [[ABC186D] Sum of difference](https://www.luogu.com.cn/problem/AT_abc186_d),可以把每个深度的个数加入到一个桶中,统计有多少个数比它小或大。 $d_u$ 和 $d_v$ 的求和也是简单的,对于每个点 $u$ 会有 $n-1$ 个询问和他有关,所以对答案的贡献即为 $(n-1)d_u$。 现在考虑这个 $-2d_{lca}$ 如何处理,需要求出有多少对点的最近公共祖先为 $lca$,显然在 $lca$ 两个不同子树内的任意点的最近公共祖先都是它。 设 $sz_u$ 表示 $u$ 的子树大小,$son_u$ 表示 $u$ 的儿子集合,存在点对数即为: $$ \dfrac 1 2 \sum\limits_{u \in son_{lca}} \sum\limits_{v \in son_{lca} \land v \neq u} sz_u \times sz_v $$ 这个东西可以前缀和优化做到线性。 最后对于每一个询问还需要减 $1$,然后会发现对于$u$ 为 $v$ 的祖先或 $v$ 为 $u$ 的祖先的情况不应该多减,所以加上这一部分就好。 时间复杂度和空间复杂度均为 $O(n)$。 ## 代码 ```cpp //the code is from chenjh #include<bits/stdc++.h> #define MAXN 300003 using namespace std; typedef long long LL; int n; vector<int> G[MAXN]; LL ans1=0,ans2=0,ans3=0; int dep[MAXN],sz[MAXN],a[MAXN]; void dfs(const int u,const int FA){ ++a[dep[u]=dep[FA]+1],sz[u]=1; ans1+=(n-1ll)*dep[u]; int x=0; LL y=0;//前缀和统计 LCA 为当前点的点对个数。 for(const int v:G[u])if(v!=FA){ dfs(v,u),sz[u]+=sz[v]; y+=(LL)sz[v]*x,x+=sz[v]; } ans1-=2*(y+sz[u]-1)*dep[u],ans3+=sz[u]-1;//减去 2d_{lca},排除祖先的情况。 } void solve(){ scanf("%d",&n); for(int i=1;i<=n;i++) G[i].clear(),a[i]=0; for(int i=1,u,v;i<n;i++){ scanf("%d%d",&u,&v); G[u].push_back(v),G[v].push_back(u); } ans1=ans2=ans3=0; dfs(1,0); for(int i=1,x=0;i<=n;x+=a[i++]) ans2+=(LL)i*a[i]*(x-(n-a[i]-x));//处理绝对值部分。 printf("%lld\n",ans1-ans2+ans3-n*(n-1ll)/2); } int main(){ int T;scanf("%d",&T); while(T--) solve(); return 0; } ```