题解 CF2063E Triangle Tree
cjh20090318
·
·
题解
年级里面有大佬用神奇启发式合并写的,还好我不会这么复杂的算法。
题意
给定一棵 n 个点的树。
设函数 $f(u,v)$ 表示:
- 当 $u$ 不为 $v$ 的祖先,且 $v$ 不为 $u$ 的祖先时,存在多少个整数 $x$ 使得边长为 $\operatorname{dist}(u,\operatorname{lca}(u,v))$ 、 $\operatorname{dist}(v,\operatorname{lca}(u,v))$、$x$ 能成为一个三角形。
- 否则函数值为 $0$。
最后需要求出:
$$
\sum_{i = 1}^{n-1} \sum_{j = i+1}^n f(i,j)
$$
## 分析
根据三角形边长的限制,可以得到:
$$
\vert \operatorname{dist}(u,\operatorname{lca}(u,v)) - \operatorname{dist}(v,\operatorname{lca}(u,v)) \vert < x < \operatorname{dist}(u,\operatorname{lca}(u,v)) + \operatorname{dist}(v,\operatorname{lca}(u,v))
$$
简单来说两边之和大于第三边,两边之差小于第三边。
把距离函数拆成深度,设 $d_u$ 为 $u$ 的深度,$lca = \operatorname{lca}(u,v)$,那么以上式子可以表示为:
$$
\vert d_u - d_v \vert < x < d_u + d_v - 2d_{lca}
$$
然后就可以表示出 $f(u,v)$:
$$
f(u,v) = (d_u + d_v - 2d_{lca}) - \vert d_u -d_v \vert - 1
$$
后面绝对值这一部分是 [[ABC186D] Sum of difference](https://www.luogu.com.cn/problem/AT_abc186_d),可以把每个深度的个数加入到一个桶中,统计有多少个数比它小或大。
$d_u$ 和 $d_v$ 的求和也是简单的,对于每个点 $u$ 会有 $n-1$ 个询问和他有关,所以对答案的贡献即为 $(n-1)d_u$。
现在考虑这个 $-2d_{lca}$ 如何处理,需要求出有多少对点的最近公共祖先为 $lca$,显然在 $lca$ 两个不同子树内的任意点的最近公共祖先都是它。
设 $sz_u$ 表示 $u$ 的子树大小,$son_u$ 表示 $u$ 的儿子集合,存在点对数即为:
$$
\dfrac 1 2 \sum\limits_{u \in son_{lca}} \sum\limits_{v \in son_{lca} \land v \neq u} sz_u \times sz_v
$$
这个东西可以前缀和优化做到线性。
最后对于每一个询问还需要减 $1$,然后会发现对于$u$ 为 $v$ 的祖先或 $v$ 为 $u$ 的祖先的情况不应该多减,所以加上这一部分就好。
时间复杂度和空间复杂度均为 $O(n)$。
## 代码
```cpp
//the code is from chenjh
#include<bits/stdc++.h>
#define MAXN 300003
using namespace std;
typedef long long LL;
int n;
vector<int> G[MAXN];
LL ans1=0,ans2=0,ans3=0;
int dep[MAXN],sz[MAXN],a[MAXN];
void dfs(const int u,const int FA){
++a[dep[u]=dep[FA]+1],sz[u]=1;
ans1+=(n-1ll)*dep[u];
int x=0;
LL y=0;//前缀和统计 LCA 为当前点的点对个数。
for(const int v:G[u])if(v!=FA){
dfs(v,u),sz[u]+=sz[v];
y+=(LL)sz[v]*x,x+=sz[v];
}
ans1-=2*(y+sz[u]-1)*dep[u],ans3+=sz[u]-1;//减去 2d_{lca},排除祖先的情况。
}
void solve(){
scanf("%d",&n);
for(int i=1;i<=n;i++) G[i].clear(),a[i]=0;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
G[u].push_back(v),G[v].push_back(u);
}
ans1=ans2=ans3=0;
dfs(1,0);
for(int i=1,x=0;i<=n;x+=a[i++]) ans2+=(LL)i*a[i]*(x-(n-a[i]-x));//处理绝对值部分。
printf("%lld\n",ans1-ans2+ans3-n*(n-1ll)/2);
}
int main(){
int T;scanf("%d",&T);
while(T--) solve();
return 0;
}
```