CF2164F2
xujindong_ · · 题解
对于每条从根出发的链,上面的点的排名是确定的。考虑如何刻画大小关系,可以建出一张图。从上往下,到达一个节点
拓扑序计数通常不可做,但这里图形态很特殊,总形如若干弓形的叠加和嵌套(加入两个虚点
我们从上面的弓形(也就是树上自底向上)开始,将上面的两条边和下面的一条边合并,这两串点是平行关系,可以任意插空。设
问题变为建图,DFS 时用平衡树维护祖先的大小关系,支持插入、删除和查第
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int c,n,rt,fac[500005],vac[500005],a[500005];
vector<int>e[500005];
map<int,int>f[500005],g[500005];
template<typename T>struct node{
int tr[2],sz,r;
T v;
};
template<typename T,int maxn>struct FHQtreap{
node<T>tr[maxn];
int cnt;
void pushup(int pos){
tr[pos].sz=tr[tr[pos].tr[0]].sz+tr[tr[pos].tr[1]].sz+1;
}
void split(int pos,int k,int&a,int&b){
if(!pos){
a=b=0;
return;
}
if(tr[tr[pos].tr[0]].sz<k)a=pos,split(tr[pos].tr[1],k-tr[tr[pos].tr[0]].sz-1,tr[a].tr[1],b);
else b=pos,split(tr[pos].tr[0],k,a,tr[b].tr[0]);
pushup(pos);
}
int merge(int a,int b){
if(!a||!b)return a^b;
if(tr[a].r<tr[b].r)return tr[a].tr[1]=merge(tr[a].tr[1],b),pushup(a),a;
else return tr[b].tr[0]=merge(a,tr[b].tr[0]),pushup(b),b;
}
void insert(int&rt,int x,T k,int a=0,int b=0){
split(rt,x,a,b),tr[++cnt].sz=1,tr[cnt].tr[0]=tr[cnt].tr[1]=0,tr[cnt].r=rand(),tr[cnt].v=k,rt=merge(merge(a,cnt),b);
}
void erase(int&rt,int x,int a=0,int b=0,int c=0){
split(rt,x-1,a,b),split(b,1,b,c),rt=merge(a,c);
}
T kth(int&rt,int x,int a=0,int b=0,int c=0,int ans=0){
return split(rt,x-1,a,b),split(b,1,b,c),ans=tr[b].v,rt=merge(merge(a,b),c),ans;
}
};
FHQtreap<int,500005>t;
int C(int n,int m){
return n<0||m<0||n<m?0:1ll*fac[n]*vac[m]%mod*vac[n-m]%mod;
}
void dfs(int pos){
int u=t.kth(rt,a[pos]+1),v=t.kth(rt,a[pos]+2);
t.insert(rt,a[pos]+1,pos),f[u][pos]=f[pos][v]=1;
for(int i=0;i<e[pos].size();i++)dfs(e[pos][i]);
f[u][v]=1ll*f[u][v]*f[u][pos]%mod*f[pos][v]%mod*C(g[u][v]+g[u][pos]+g[pos][v]+1,g[u][v])%mod;
g[u][v]+=g[u][pos]+g[pos][v]+1;
t.erase(rt,a[pos]+2);
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),srand(time(0)),fac[0]=vac[0]=vac[1]=1;
for(int i=2;i<=500000;i++)vac[i]=1ll*vac[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=500000;i++)fac[i]=1ll*fac[i-1]*i%mod,vac[i]=1ll*vac[i]*vac[i-1]%mod;
cin>>c;
while(c--){
cin>>n;
for(int i=2,f;i<=n;i++)cin>>f,e[f].push_back(i);
for(int i=1;i<=n;i++)cin>>a[i];
t.insert(rt,1,0),t.insert(rt,2,n+1),f[0][n+1]=1,dfs(1),cout<<f[0][n+1]<<'\n';
for(int i=0;i<=n;i++)e[i].clear(),f[i].clear(),g[i].clear();
rt=t.cnt=0;
}
return 0;
}