[ARC198D] Many Palindromes on Tree 题解

· · 题解

对于树上 xy 的简单路径 v_1=x,v_2,\ldots,v_{k-1},v_k=y,令 \mathrm{next}(x,y) 的值为一个二元组 (v_2,v_{k-1});如果 x,y 之间的距离 \le 1,令 \mathrm{next}(x,y)=(-1,-1)。这个信息可以通过 N 遍 dfs O(N^2) 预处理出来。

题目的条件等价于限定了若干组 (u,v) 需要满足 x_u=x_v,而想要让最终的答案越小,x 中的元素显然要尽量不同。使用并查集维护 x_u=x_v 的条件,考虑所有 A_{i,j}=1 的点对 (i,j),需要满足 x_i=x_j,令 (i',j')=\mathrm{next}(i,j),还需要满足 x_{i'}=x_{j'}……以此类推。直接模拟这个合并过程是 O(N^3) 的,如果一个点对 (i,j) 已经被考虑过(可以用一个 bool 数组记录一下),那么直接退出循环,这样时间复杂度就是 O(N^2) 的(每个点对最多被遍历一次)。

这里有一个坑点:判断点对是否被考虑过,千万不能直接判断并查集中它们是否在同一个集合中!因为可能有一条限制是 x_u=x_v,另一条限制是 x_v=x_w,考虑到点对 (u,w) 的时候由于在并查集中两者所在集合相同,所以直接 break 后,\mathrm{next}(u,w) 就没有被合并,导致出现错误。赛时因为这个错误被这题击败了,如果你最后没反应过来,十万年都调不出来。

知道了 x 序列的值后统计答案就是简单的,直接路径哈希:h_{u,v} 表示 u\to v 上的结点的 x 的值顺序拼接形成的序列哈希值,一条路径 (u,v) 是回文的当且仅当 h_{u,v}=h_{v,u},而 h 也是可以 N 遍 dfs 预处理的,时间复杂度 O(N^2)

所以总的时间复杂度 O(N^2),可以通过。

放代码:

#include<bits/stdc++.h>
#include<atcoder/all>
using namespace std;
typedef pair<int,int> pii;
const int B=6907;
int main(){
  ios::sync_with_stdio(false);
  cin.tie(0); cout.tie(0);
  int n,c=0; cin>>n;
  vector<vector<int> > g(n);
  for(int i=1;i<n;i++){
    int u,v; cin>>u>>v;
    g[--u].emplace_back(--v);
    g[v].emplace_back(u);
  }
  vector nx(n,vector<pii>(n,make_pair(-1,-1)));
  // 对应文章中的 next 数组
  for(int r=0;r<n;r++){
    vector<int> v;
    auto dfs=[&](auto &&self,int u,int f)->void{
      if(v.size()>2)nx[r][u]=make_pair(v[1],v[v.size()-2]);
      for(int i:g[u])
        if(i!=f)v.emplace_back(i),self(self,i,u),v.pop_back();
    };
    v={r},dfs(dfs,r,r);
  } // dfs 求 next 的值
  atcoder::dsu s(n);
  vector b(n,vector<bool>(n));
  // 标记一个点对是否被考虑过
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++){
      char x; cin>>x;
      if(i==j)continue;
      if(x&1){
        int u=i,v=j;
        while(~u&&!b[u][v])
          b[u][v]=true,s.merge(u,v),tie(u,v)=nx[u][v];
      }
    } // 往上跳 next,并查集维护合并
  vector<int> ld(n);
  for(int i=0;i<n;i++)
    ld[i]=s.leader(i)+1;
  vector h(n,vector<unsigned long long>(n));
  for(int r=0;r<n;r++){
    auto dfs=[&](auto &&self,int u,int f)->void{
      for(int i:g[u])
        if(i!=f)h[r][i]=h[r][u]*B+ld[i],self(self,i,u);
    }; // 计算路径哈希
    h[r][r]=ld[r],dfs(dfs,r,r);
  }
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)
      if(h[i][j]==h[j][i])c++;
  cout<<c<<endl;
  return 0;
}