CF2063E Triangle Tree

· · 题解

先只考虑产生贡献的方案。

dis_1=dist(u,lca(u,v)), dis_2=dist(v,lca(u,v)),且 dis_1\ge dis_2,第三边长为 x,那么有 dis_1-dis_2< x< dis_1+dis_2,容易发现 x 的取值方案有 2\times dis_2-1 种。所以当 f\left ( u,v \right )\ne 0 时,我们可以令 f\left ( u,v \right )= 2\times min\left ( dist(u,lca(u,v)),dist(v,lca(u,v)) \right ) -1

首先把 -1 的贡献抽出来,也就是所有的选择方案数,这个跑一遍树形 dp 就可以求出来。

接下来计算出 \sum_{i=1}^{n}\sum_{j=1}^{n} min\left ( dist(i,lca(i,j)),dist(j,lca(i,j)) \right ) 即可,f(i,j),f(j,i) 正好算了两次。

这个形式已经可以用启发式合并做了,但我们还可以继续拆贡献。

借第二段中的变量继续分析,如果有 dis_1>dis_2 那么我们完全可以先把 u 向上跳到和 v 相同深度的位置 u' 再统计,可以发现这并不会影响答案,然后我们就可以同时把 u'v 暴力向上跳,每跳一次贡献增加 1,跳到同一点 \left ( lca(u,v) \right ) 时退出即可。

当然,直接暴力枚举点对并暴力跳肯定是不行的,这样还不如用前面推出的式子直接求,但这启示我们,我们可以每次只考虑同一深度的节点,每次处理完一层后把这层的点全部跳到上一层,然后进行相同的操作,这样每个节点只访问一次,时间复杂度是 O(n) 的。

考虑同深度的点的整体贡献,显然一个点与任意一个不同的点之间都能产生贡献。所以我们只需求出当前层点数总和就可以快速计算。

最后的答案要记得除去 -1 的贡献。

#include<bits/stdc++.h>
using namespace std;
template <typename T>
void in(T &x){
    char c=getchar(), f=1;
    while ((c<'0' || c>'9') && c!='-') c=getchar();
    if (c=='-') f=-1, c=getchar();
    for (x=0; c>='0' && c<='9'; c=getchar())
        x=x*10+c-'0';
    x*=f;
}
const int N=3e5+5;
struct Node{
    int u,v,nxt;
}a[N*2];
int last[N],cnt,u,v,F[N],sz[N];
void add(int u,int v){
    a[++cnt]={u,v,last[u]};
    last[u]=cnt;
}
int T,n,mxd;
long long ans=0;
vector<int>dep[N];
void dfs(int u,int fa,int depth){
    F[u]=fa;sz[u]=1;
    mxd=max(mxd,depth);
    dep[depth].push_back(u);
    for(int i=last[u];i;i=a[i].nxt){
        int v=a[i].v;
        if(v==fa)continue;
        dfs(v,u,depth+1);
        sz[u]+=sz[v];   
    }
    ans-=n-sz[u]-depth+1;
}
int main(){
//  freopen("E.in","r",stdin);
//  freopen("E.out","w",stdout);
    in(T);
    while(T--){
        in(n);
        for(int i=1;i<=mxd;i++)dep[i].clear();
        for(int i=1;i<=n;i++)last[i]=0;
        cnt=mxd=ans=0;
        for(int i=1;i<n;i++){
            in(u);in(v);
            add(u,v);add(v,u);
        }
        dfs(1,0,1);
        ans/=2;
        for(int i=mxd;i>1;i--){
            int al=0;
            for(auto j:dep[i])al+=sz[j];
            for(auto j:dep[i]){
                if(al==sz[j])break;
                ans+=sz[j]*1ll*(al-sz[j]);
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}