题解:CF2231E Graph Cutting

· · 题解

感谢这题送我上 Candidate Master!

枚举三者之一的点 u 把它拎起来当根,此时问题转化为算有多少点对 (v,w) 满足 dep_v+dep_w-dep_{{\rm LCA}(v,w)}+1=d

直接启发式合并枚举 \rm LCA 就可以单次 O(n\log n) 算了。具体来说做一个 dfs,每次先单独走轻子树统计答案,统计完清空深度桶;然后走重儿子统计答案;然后依次暴力扫每个轻儿子的子树利用深度桶统计答案,走完一个轻儿子加上这个轻儿子的桶的贡献。最后统计根对所有儿子的答案即可(可以看代码)。

最后答案除以 3 即可。总复杂度 O(n^2\log n)

::::success[Code]

#include<bits/stdc++.h>
using namespace std;

#define int long long
#define MAXN 2005

int n,d,a[MAXN],dep[MAXN],siz[MAXN],son[MAXN],cnt[MAXN];
int Ans = 0;
vector<int> E[MAXN];

void init( int x , int fa ){
    siz[x] = 1,son[x] = 0;
    for( int v : E[x] ){
        if( v == fa ) continue;
        dep[v] = dep[x] + 1,init( v , x ),siz[x] += siz[v];
        if( !son[x] || siz[v] > siz[son[x]] ) son[x] = v;
    }
}

void modi_single( int x , int fa , int k ){
    cnt[dep[x]] += k;
    for( int v : E[x] ){
        if( v == fa ) continue;
        modi_single( v , x , k );
    }
}

void chk_single( int x , int fa , int aim ){
    if( aim - dep[x] >= 0 && aim - dep[x] < n ) Ans += cnt[aim - dep[x]];
    for( int v : E[x] ){
        if( v == fa ) continue;
        chk_single( v , x , aim );
    }
}

void calc( int x , int fa ){
    for( int v : E[x] ){
        if( v == fa ) continue;
        if( v != son[x] ){
            calc( v , x );
            modi_single( v , x , -1 );
        }
    }
    if( son[x] ) calc( son[x] , x );
    int Aim = d - 1 + dep[x];
    for( int v : E[x] ){
        if( v == son[x] || v == fa ) continue;
        chk_single( v , x , Aim ),modi_single( v , x , 1 );
    }
    if( fa ){
        if( Aim - dep[x] >= 0 && Aim - dep[x] < n ) Ans += cnt[Aim - dep[x]];
        cnt[dep[x]] ++;
    }
}

inline void solve(){
    scanf("%lld%lld",&n,&d);
    for( int i = 1 ; i < n ; i ++ ){
        int u,v; scanf("%lld%lld",&u,&v);
        E[u].emplace_back( v ),E[v].emplace_back( u );
    }
    for( int rt = 1 ; rt <= n ; rt ++ ){
        init( rt , 0 ),calc( rt , 0 );
        for( int i = 0 ; i <= n ; i ++ ) cnt[i] = siz[i] = son[i] = dep[i] = 0;
    }
    printf("%lld\n",Ans / 3);
    for( int i = 0 ; i <= n ; i ++ ) a[i] = dep[i] = siz[i] = son[i] = cnt[i] = 0,E[i].clear();
    Ans = 0;
}

signed main(){
    int t; scanf("%lld",&t);
    while( t -- ) solve();
    return 0;
}

::::