[ABC359G] Sum of Tree Distance 题解

· · 题解

该题为 CF1725E 弱化版。评价是 G 比 C 简单。

考虑使用虚树,把每种颜色的点抽出来建立虚树,在虚树上跑一遍 dfs 直接算贡献即可。具体地,考虑虚树上每一条边的贡献:对于这条边连接的两个子树,分别考虑两个子树内有多少个颜色为当前考察的颜色的点,然后乘上这条边的长度。

虚树上一条边的长度定义为其连接的两个结点在原树中的距离,这个可以通过预处理深度以及倍增处理 \mathrm{lca} 来实现。正确性显然。

时限很大所以不怕大常数。

放代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
class virtual_tree{
  private:
    int n,k;
    vector<int> dep,dfn,f,e;
    vector<vector<int> > g,s;
  public:
    virtual_tree(vector<vector<int> > &t){
      n=t.size(),k=__lg(n),f.resize(n);
      dep.resize(n),dfn.resize(n),e.resize(n),g.resize(n);
      s.resize(n,vector<int>(k+1));
      int o=0;
      function<void(int,int)> dfs=[&](int u,int f){
        dfn[u]=o++,s[u][0]=f;
        for(int i=1;i<=k;i++)
          s[u][i]=s[s[u][i-1]][i-1];
        for(int i:t[u])
          if(i!=f)dep[i]=dep[u]+1,dfs(i,u);
      };
      dfs(0,0);
    }
    inline int lca(int u,int v){
      if(dep[u]<dep[v])swap(u,v);
      for(int i=k;~i;i--)
        if(dep[s[u][i]]>=dep[v])u=s[u][i];
      if(u==v)return u;
      for(int i=k;~i;i--)
        if(s[u][i]!=s[v][i])u=s[u][i],v=s[v][i];
      return s[u][0];
    }
    inline int build(vector<int> &c){
      auto v=c;
      sort(v.begin(),v.end(),[&](int x,int y){
        return dfn[x]<dfn[y];
      });
      int n0=v.size();
      for(int i=1;i<n0;i++)
        v.emplace_back(lca(v[i-1],v[i]));
      sort(v.begin(),v.end(),[&](int x,int y){
        return dfn[x]<dfn[y];
      });
      int n=unique(v.begin(),v.end())-v.begin(),s=0;
      for(int i=1;i<n;i++)
        g[lca(v[i-1],v[i])].emplace_back(v[i]);
      // 建立虚树
      for(int i:c)e[i]=1;
      function<void(int)> dfs=[&](int u){
        for(int i:g[u]){
          dfs(i),e[u]+=e[i];
          s+=e[i]*(c.size()-e[i])*(dep[i]-dep[u]);
        } // 跑贡献
      };
      dfs(v[0]);
      for(int i=0;i<n;i++)
        vector<int>().swap(g[v[i]]),e[v[i]]=0;
      // 记得清空
      return s;
    }
};
main(){
  ios::sync_with_stdio(false);
  int n,r=0; cin>>n;
  vector<int> a(n);
  vector<vector<int> > V(n),g(n);
  for(int i=1;i<n;i++){
    int u,v; cin>>u>>v;
    g[--u].emplace_back(--v);
    g[v].emplace_back(u);
  }
  for(int i=0;i<n;i++)
    cin>>a[i],V[--a[i]].emplace_back(i);
  virtual_tree t(g);
  for(int i=0;i<n;i++)
    if(V[i].size()>1)r+=t.build(V[i]);
  cout<<r<<endl;
  return 0;
}