题解 P4323 【[JSOI2016]独特的树叶】

· · 题解

题目链接:传送门

Update1:更新了\LaTeX

思路:

思路参考了这篇博客的代码。

我选择的哈希方式是使用质数表进行哈希,即:

f(x)=\sum\limits_{y\in Son_x}f(y)\times\text{P}(Size(y))\mod M

其中 P(x) 表示质数表中的第 x 项, Size(x) 表示以 x 为根的子树的大小, f(x) 初始化为 1 ,M 为模数。

对于一棵有根树来说,这种方法可以在 \mathcal{O(n)} 的时间内处理出该树的哈希值。

而题目是一颗有根树,可以想到一种朴素做法,即以每个节点为根跑一边哈希,时间复杂度为 \mathcal{O(n^2)} 无法满足题目的数据范围。

以题目样例给的图为例:

可以先求出以 x 为根的树的哈希值,记为 f(x) ,同时被一并求出的还有以 y 为根的子树的哈希值,记为 f(y)

根据上面的定义我们可以推导出以 y 为根的哈希值 g(y)f(y)+h(x)

根据定义可得: $$h(x)=(f(x)-f(y)\times\text{P}(Size(y))\times\text{P}(n-Size(y))\mod M$$ 对于任意的 $y$ ,有: $$g(y)=f(y)+(f(x)-f(y)\times\text{P}(Size(y))\times\text{P}(n-Size(y))\mod M$$ 这样就可以通过一次 DFS 来求出以任意节点为根的树的哈希值了,时间复杂度降为 $\mathcal{O(n)}

我们可以使用 std::set 来维护第一颗树的哈希值。

而对于第二棵树,我们可以先记录下该树的叶子节点,并且通过同样的方式计算出每个节点为根的哈希值。

再通过枚举计算删除一个叶子节点后该树的哈希值并将该值与先前 std::set 中的值进行比较就可判断删除叶子节点后的树是否与第一颗树同构。

特别的,由于叶子节点的大小为 1 ,所以删除一个叶子节点后的哈希值为以该叶子结点的父亲作为根的哈希值减去 \text{P}(1) ,可以自己思考一下为什么。

时间复杂度分析:

哈希预处理的复杂度为 \mathcal{O(n)}

处理质数表时间复杂度为线性,即 \mathcal{O(n)}

std::set 的本质是一颗红黑树,单次查询的时间复杂度为 \mathcal{O(\log n)}

总的时间复杂度为 \mathcal{O(n \log n)}

代 码 放 送:

既然你能找到这题,我相信你能瞬间做出来的。

Code:
#include<bits/stdc++.h>
#define Max(x,y) (x>y?x:y)
#define Min(x,y) (x<y?x:y)
using namespace std;
const long long N=100010,M=1000010,lpw=1e9+7;
long long head[N],ver[M],Next[M],tot;
long long n,P[M<<1],v[M<<1],cnt;
long long f[N],g[N],Size[N];
long long ind[N],fa[N];
set<long long> st;
void add(long long x,long long y){
    ver[++tot]=y,Next[tot]=head[x],head[x]=tot;
}
void init(){
    memset(head,0,sizeof(head)),tot=0,n++;
}
void Hash1(long long x,long long fa){
    Size[x]=f[x]=1;
    for(long long i=head[x];i;i=Next[i]){
        long long y=ver[i];
        if(y==fa)continue;
        Hash1(y,x);
        Size[x]+=Size[y];
        f[x]=(f[x]+f[y]*P[Size[y]]%lpw)%lpw;
    }
}
void Hash2(long long x,long long fa,long long L){
    g[x]=(f[x]+L*P[n-Size[x]]%lpw)%lpw;
    for(long long i=head[x];i;i=Next[i]){
        long long y=ver[i];
        if(y==fa)continue;
        Hash2(y,x,(g[x]-f[y]*P[Size[y]]%lpw+lpw)%lpw);
    }
}
void GetP(long long x){
    for(long long i=2;i<=x;i++){
        if(!v[i])P[++cnt]=i;
        for(long long j=1;j<=cnt&&i*P[j]<=x;j++){
            v[i*P[j]]=1;
            if(!(i%P[j]))break;
        }
    }
}
int main(){
    GetP((M<<1)-1);
    scanf("%lld",&n);
    for(long long i=1;i<n;i++){
        long long x,y;
        scanf("%lld%lld",&x,&y);
        add(x,y),add(y,x);
    }
    Hash1(1,0);
    Hash2(1,0,0);
    for(long long i=1;i<=n;i++)
        st.insert(g[i]);
    init();
    for(long long i=1;i<n;i++){
        long long x,y;
        scanf("%lld%lld",&x,&y);
        add(x,y),add(y,x);
        ind[x]++,ind[y]++;
        fa[x]=y,fa[y]=x;
    }
    Hash1(1,0);
    Hash2(1,0,0);
    for(long long i=1;i<=n;i++){
        if(ind[i]>1)continue;
        int x=(g[fa[i]]-P[1]+lpw)%lpw;
        if(st.find(x)!=st.end()){
            printf("%lld\n",i);
            return 0;
        }
    }
    return 0;
}