Happy Life in University

· · 题解

题目大意

给你一棵 n 个点的二叉树,每个节点有一个颜色 a_i,定义 f(u,v)u \rightarrow v 的简单路径上的颜色个数。对于所有点对 (u,v),其中 u,v 可以相等,设 lu,v 的最近公共祖先,求 max\{f(l,u)\times f(l,v)\}

思路

这道题我们可以给每个点重新编号,按照深度优先搜索的顺序。然后我们就把每个子树转成了一段连续的区间。然后我们去枚举最近公共祖先,利用线段树维护以我当前枚举的这个最近公共祖先的子树中的答案最大值和次大值,这个子树的答案就是最大值乘以次大值。

代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll n;
ll dfn[300005],dfn2[300005];
vector<ll> g[300005],jian[300005];
ll mx[1200005],tag[1200005];
stack<ll> st[300005];
ll cnt=0;
ll a[300005];
void dfs(ll x) {
    dfn[x]=++cnt;
    for (auto i:g[x]) {
        dfs(i);
    }
    dfn2[x]=cnt;
}
ll ans;
void cj(ll x) {
    if(!st[a[x]].empty()) {
        jian[st[a[x]].top()].push_back(x);
    }
    st[a[x]].push(x);
    for(auto i:g[x]) {
        cj(i);
    }
    st[a[x]].pop();
}
ll lc(ll p) {
    return p<<1;
}
ll rc(ll p) {
    return p<<1|1;
}
void pushdown(ll p) {
    if(tag[p]) {
        tag[lc(p)]+=tag[p],mx[lc(p)]+=tag[p];
        tag[rc(p)]+=tag[p],mx[rc(p)]+=tag[p];
        tag[p]=0;
    }
}
void pushup(ll p) {
    mx[p]=max(mx[lc(p)],mx[rc(p)]);
}
void upd(ll p,ll l,ll r,ll ql,ll qr,ll x) {
    if(ql<=l&&r<=qr) {
        tag[p]+=x;
        mx[p]+=x;
        return ;
    }
    pushdown(p);
    ll mid=(l+r)>>1;
    if(ql<=mid) {
        upd(lc(p),l,mid,ql,qr,x);
    }
    if(mid<qr) {
        upd(rc(p),mid+1,r,ql,qr,x);
    }
    pushup(p);
}
ll query(ll p,ll l,ll r,ll ql,ll qr) {
    if(ql<=l&&r<=qr) {
        return mx[p];
    }
    pushdown(p);
    ll mid=(l+r)>>1;
    ll ret=0;
    if(ql<=mid) {
        ret=max(ret,query(lc(p),l,mid,ql,qr));
    }
    if(mid<qr) {
        ret=max(ret,query(rc(p),mid+1,r,ql,qr));
    }
    return ret;
}
void solve(ll x) {
    for (auto i:g[x]) {
        solve(i);
    }
    upd(1,1,n,dfn[x],dfn2[x],1);
    for (auto i:jian[x]) {
        upd(1,1,n,dfn[i],dfn2[i],-1);
    }
    ll mx1=1,mx2=1;
    for (auto i:g[x]) {
        ll res=query(1,1,n,dfn[i],dfn2[i]);
        if(res>=mx1) {
            mx2=mx1;
            mx1=res;
        } else if(res>=mx2) {
            mx2=res;
        }
    }
    ans=max(ans,mx1*mx2);
}
int main() {
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    ll t;
    cin>>t;
    while(t--) {
        cnt=0;
        ans=1;
        cin>>n;
        for(int i=1; i<=4*n; i++) {
            mx[i]=0;
            tag[i]=0;
        }
        for (int i=1; i<=n; i++) {
            g[i].clear();
            jian[i].clear();
            dfn[i]=0;
            dfn2[i]=0;
            while(!st[i].empty()) {
                st[i].pop();
            }
        }
        for(int i=2; i<=n; i++) {
            ll fa;
            cin>>fa;
            g[fa].push_back(i);
        }
        for (int i=1; i<=n; i++) {
            cin>>a[i];
        }
        dfs(1);
        cj(1);
        solve(1);
        cout<<ans<<endl;
    }
    return 0;
}