P5904 [POI2014] HOT-Hotels 加强版 题解

· · 题解

线段树合并+点分治。

容易想到枚举中心点(即到三点的距离相同的点)考虑,如果把树视作有根树,至少有两个点是在中间点的子树中的,此时它们只需要满足深度相同就满足条件。

我们只需要在这两个节点所在的子树外再找一个到中心点距离相同的节点就可以了。通过二次离线和简单容斥,我们可以将问题转化为求 O(n\log n) 次距离 xk 的节点个数,点分治即可。

一个比较关键的点是我们只有在线段树合并两颗树都递归到叶子节点的时候才会产生询问,而根据线段树合并的复杂度,我们会产生 O(n\log n) 次询问。

最终复杂度 O(n\log^2 n)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int n;
vector<int> lin[N];
struct Q{
    int k,cnt1,cnt2,ans;
};
vector<Q> vec[N]; 
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
#define cnt(x) tr[x].cnt
struct seg{
    int lc,rc,cnt;
}tr[N*30];
int tot;
int dep[N],rt[N];
void add(int &s,int l,int r,int wz){
    if(!s) s=++tot;
    cnt(s)++;
    if(l==r) return;
    int mid=l+r>>1;
    if(wz<=mid) add(lc(s),l,mid,wz);
    else        add(rc(s),mid+1,r,wz);
}
int mer(int x,int y,int l,int r,int X){
    if(!x||!y) return x|y;
    if(l==r) vec[X].push_back({l-dep[X],cnt(x),cnt(y),0});
    cnt(x)+=cnt(y);
    int mid=l+r>>1;
    return lc(x)=mer(lc(x),lc(y),l,mid,X),rc(x)=mer(rc(x),rc(y),mid+1,r,X),x; 
}
void DFS1(int x,int fa){
    add(rt[x],0,n,dep[x]);
    for(int to:lin[x]) if(to!=fa) dep[to]=dep[x]+1,DFS1(to,x),rt[x]=mer(rt[x],rt[to],0,n,x);
}

int bjt[N],siz[N],mx[N],root,sum_siz;
void DFS_root(int x,int fa){
    siz[x]=1,mx[x]=0;
    for(int to:lin[x]) if(to!=fa&&!bjt[to]) DFS_root(to,x),siz[x]+=siz[to],mx[x]=max(mx[x],siz[to]);
    mx[x]=max(mx[x],sum_siz-siz[x]);
    if(root==-1||mx[x]<mx[root]) root=x;
}
int find_root(int x){
    return sum_siz=siz[x],root=-1,DFS_root(x,-1),root;
}

int cnt[N];
void DFS_get_son(int x,int fa,vector<int> &nd){
    dep[x]=dep[fa]+1,nd.push_back(x);
    for(int to:lin[x]) if(to!=fa&&!bjt[to]) DFS_get_son(to,x,nd); 
}
void calc(int x){
    dep[x]=0;
    vector<vector<int>> v;
    v.push_back(vector<int>{x});
    for(int to:lin[x]) if(!bjt[to]) v.push_back(vector<int>()),DFS_get_son(to,x,v.back());

    for(auto &TO:v) for(int to:TO) cnt[dep[to]]++;
    for(auto &TO:v){
        for(int to:TO) cnt[dep[to]]--;
        for(int to:TO) for(Q &el:vec[to]){
            if(el.k>=dep[to]) el.ans+=cnt[el.k-dep[to]];
        }
        for(int to:TO) cnt[dep[to]]++;
    }
    for(auto &TO:v) for(int to:TO) cnt[dep[to]]--;
}
void solve(int zx){
    calc(zx);
    bjt[zx]=1;
    for(int to:lin[zx]) if(!bjt[to]) solve(find_root(to));
}
int ask_dfs(int x,int fa,int k,int cd){
    int ans=0;
    for(int to:lin[x]) if(to!=fa) ans+=ask_dfs(to,x,k,cd+1);
    return ans+(cd==k);
}
signed main(){
    cin.tie(0)->sync_with_stdio(0);
    cin>>n;
    for(int i=1,u,v;i<n;i++) cin>>u>>v,lin[u].push_back(v),lin[v].push_back(u);
    DFS1(1,-1);
    siz[1]=n,solve(find_root(1));

    int ans=0;
    for(int i=1;i<=n;i++) for(auto to:vec[i]) ans+=(to.ans-to.cnt1-to.cnt2)*to.cnt1*to.cnt2;
    cout<<ans; 
    return 0;
}