题解:CF2063E Triangle Tree
hanhoudedidue · · 题解
分析
首先注意到:给定两条边长分别为
再来看题目所求,对于任意两个点,求以这两个点到它们的 LCA 的距离为两条边构成的三角形个数。发现和 LCA 有关,故考虑树上启发式合并。在 LCA 处计算贡献,在枚举该节点其中一棵子树的所有节点时,分两种情况考虑:
(一):这个节点到当前 LCA 的距离作为最小值时,其贡献即为其它子树内(到当前 LCA 的)距离大于该节点的距离的个数乘上两倍该距离减一。
(二):这个节点到当前 LCA 的距离作为较大值时,贡献为两倍其它子树内(到当前 LCA 的)距离小于该节点的距离总和减去其它子树内(到当前 LCA 的)距离小于该节点的个数。
使用树状树组维护即可。时间复杂度
code:
#include<bits/stdc++.h>
#define int long long
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
#define mid ((l+r)>>1)
#define lowbit(x) ((x)&(-x))
using namespace std;
const int N=6e5+5;
int n,m,T;
struct edge{
int to,nxt;
}a[N];
int h[N],cnt,siz[N],son[N];
void add(int u,int v){
a[++cnt].to=v;
a[cnt].nxt=h[u];
h[u]=cnt;
}
int ans;
int L[N],R[N],rk,rev[N],dep[N];
struct BIT{
int t[N];
void add(int x,int sum){
if(!x) return;
for(;x<=n;x+=lowbit(x)) t[x]+=sum;
}
int query(int x){
int sum=0;
for(;x;x-=lowbit(x)) sum+=t[x];
return sum;
}
}b1,b2;
void dfs1(int u,int fa){
siz[u]=1;L[u]=++rk;rev[rk]=u;
dep[u]=dep[fa]+1;
for(int i=h[u];i;i=a[i].nxt){
int v=a[i].to;if(v==fa) continue;
dfs1(v,u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
R[u]=rk;
return;
}
void Add(int x){
b1.add(x,x);
b2.add(x,1);
}
void Del(int x){
b1.add(x,-x);
b2.add(x,-1);
}
int Query(int x,int rt){
int res=0;
res+=(b2.query(n)-b2.query(x))*(2*(x-rt)-1);//dis=x-rt
res+=(2*b1.query(x)-2*b2.query(x)*rt-b2.query(x));//n*(2*(down-rt)-1)
return res;
}
void dfs2(int u,int fa,bool keep){
for(int i=h[u];i;i=a[i].nxt){
int v=a[i].to;if(v==son[u]||v==fa) continue;
dfs2(v,u,0);
}
if(son[u]) dfs2(son[u],u,1);
int kk=ans;
for(int i=h[u];i;i=a[i].nxt){
int v=a[i].to;if(v==fa||v==son[u]) continue;
for(int j=L[v];j<=R[v];j++) ans+=Query(dep[rev[j]],dep[u]);
for(int j=L[v];j<=R[v];j++) Add(dep[rev[j]]);
}
// cerr<<u<<':'<<ans-kk<<'\n';
Add(dep[u]);
if(!keep){
for(int i=L[u];i<=R[u];i++) Del(dep[rev[i]]);
}
}
void init(){
for(int i=1;i<=n;i++) b1.t[i]=b2.t[i]=0;
for(int i=1;i<=n;i++) L[i]=R[i]=rev[i]=dep[i]=0;
for(int i=1;i<=n;i++) a[i]=edge{0,0},h[i]=siz[i]=son[i]=0;
rk=cnt=ans=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>T;
while(T--){
init();
cin>>n;
for(int i=1,u,v;i<n;i++){
cin>>u>>v;
add(u,v);add(v,u);
}
dfs1(1,0);
dfs2(1,0,1);
cout<<ans<<'\n';
}
return 0;
}