题解:CF2063E Triangle Tree

· · 题解

分析

首先注意到:给定两条边长分别为 a,b(a<b) 的三角形,另外一条边长 c 显然有 b-a<c<b+a,取值便为 2a-1。即最小值的两倍减一。

再来看题目所求,对于任意两个点,求以这两个点到它们的 LCA 的距离为两条边构成的三角形个数。发现和 LCA 有关,故考虑树上启发式合并。在 LCA 处计算贡献,在枚举该节点其中一棵子树的所有节点时,分两种情况考虑:

(一):这个节点到当前 LCA 的距离作为最小值时,其贡献即为其它子树内(到当前 LCA 的)距离大于该节点的距离的个数乘上两倍该距离减一。

(二):这个节点到当前 LCA 的距离作为较大值时,贡献为两倍其它子树内(到当前 LCA 的)距离小于该节点的距离总和减去其它子树内(到当前 LCA 的)距离小于该节点的个数。

使用树状树组维护即可。时间复杂度 O(n\log^2 n)。具体实现见代码。

code:

#include<bits/stdc++.h>
#define int long long
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
#define mid ((l+r)>>1)
#define lowbit(x) ((x)&(-x))
using namespace std;
const int N=6e5+5;
int n,m,T;
struct edge{
    int to,nxt;
}a[N];
int h[N],cnt,siz[N],son[N];
void add(int u,int v){
    a[++cnt].to=v;
    a[cnt].nxt=h[u];
    h[u]=cnt;
}
int ans;

int L[N],R[N],rk,rev[N],dep[N];
struct BIT{
    int t[N];
    void add(int x,int sum){
        if(!x) return;
        for(;x<=n;x+=lowbit(x)) t[x]+=sum;
    }
    int query(int x){
        int sum=0;
        for(;x;x-=lowbit(x)) sum+=t[x];
        return sum;
    }
}b1,b2;
void dfs1(int u,int fa){
    siz[u]=1;L[u]=++rk;rev[rk]=u;
    dep[u]=dep[fa]+1;
    for(int i=h[u];i;i=a[i].nxt){
        int v=a[i].to;if(v==fa) continue;
        dfs1(v,u);siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
    R[u]=rk;
    return;
}
void Add(int x){
    b1.add(x,x);
    b2.add(x,1);
}
void Del(int x){
    b1.add(x,-x);
    b2.add(x,-1);
}
int Query(int x,int rt){
    int res=0;
    res+=(b2.query(n)-b2.query(x))*(2*(x-rt)-1);//dis=x-rt
    res+=(2*b1.query(x)-2*b2.query(x)*rt-b2.query(x));//n*(2*(down-rt)-1)
    return res;
}
void dfs2(int u,int fa,bool keep){
    for(int i=h[u];i;i=a[i].nxt){
        int v=a[i].to;if(v==son[u]||v==fa) continue;
        dfs2(v,u,0); 
    }
    if(son[u]) dfs2(son[u],u,1);
    int kk=ans;
    for(int i=h[u];i;i=a[i].nxt){
        int v=a[i].to;if(v==fa||v==son[u]) continue;
        for(int j=L[v];j<=R[v];j++) ans+=Query(dep[rev[j]],dep[u]);
        for(int j=L[v];j<=R[v];j++) Add(dep[rev[j]]);
    }
    // cerr<<u<<':'<<ans-kk<<'\n';
    Add(dep[u]);
    if(!keep){
        for(int i=L[u];i<=R[u];i++) Del(dep[rev[i]]);
    }
}   
void init(){
    for(int i=1;i<=n;i++) b1.t[i]=b2.t[i]=0;
    for(int i=1;i<=n;i++) L[i]=R[i]=rev[i]=dep[i]=0;
    for(int i=1;i<=n;i++) a[i]=edge{0,0},h[i]=siz[i]=son[i]=0;
    rk=cnt=ans=0;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>T;
    while(T--){
        init();
        cin>>n;
        for(int i=1,u,v;i<n;i++){
            cin>>u>>v;
            add(u,v);add(v,u);
        }
        dfs1(1,0);
        dfs2(1,0,1);
        cout<<ans<<'\n';
    }
    return 0;
}