题解:CF2164F2 Chain Prefix Rank (Hard Version)
详细揭秘不会求 dag 拓扑序应该如何解决这道题。
小引理:对于一个排列
于是我们相当于已经知道树上任意两个为祖先-后代的关系的点的权值的大小关系,现在要求方案数。
可以像这样做:对于一个点
连出的 dag 的拓扑序个数即为答案。
但是问题在于我们并不会求 dag 的拓扑序,所以大概想到,要在树的结构上考虑。
我们考察这样计算:对于一个点
然后我们考虑如何从只考虑与
我们考察将
如图,链上有四个点,分成了五个等价类。
然后我们加入
考虑如何计算增加了限制后的方案数。
实际上就是,原本这一个等价类里面的数可以随便排,然后被切成两个等价类,现在这两个等价类内部可以随便排。
设插入一个
意思是原本算的方案里,这个等价类可以随便排,现在拆成两个了,先除一个随便排,再乘上两个随便排。
更形式化的说,设
这里还有另一种理解方法,我们知道目前有
于是我们只要计算出这个和子树相关的信息即可。
但是我们发现题目的给出的信息还是有点抽象,我想要维护出大小关系只能用平衡树维护一条链,因此很难计算子树内维护两个点所对应的数之间的点的个数。
注意到,我们可以求出一组满足限制条件的初始解,这样我们就能快速的比较大小关系,而且也不会比较错。
这部分随便跑一组拓扑序就行,然后后面的部分就只需要做一个二维数点,复杂度
/*
目前以 u 为根的子树已经满足了 fau 到 1 链上的点的大小关系,有一个方案数
插入一个 u,限制会变紧
设 u 前驱是 x,后继是 y
原本位于 [x,y] 这一段里面的数,现在有一部分要去 [x,u] 有一部分要去 [u,y]
除掉原本 [x,y] 这一段的贡献,乘上 [x,u],[u,y] 这两端的贡献
*/
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define lowbit(x) (x&(-x))
const int mod = 998244353;
const int M = 1000000;
int qp(int p,int q){
int ans = 1,pro = p;
while(q){
if(q&1)ans = ans*pro%mod;
pro = pro*pro%mod;q>>=1;
}
return ans;
}
int jie[1000005],inv[1000005];
void init(){
jie[0] = 1;for(int i = 1;i<=M;i++)jie[i] = jie[i-1]*i%mod;
inv[M] = qp(jie[M],mod-2);
for(int i = M-1;i>=0;i--)inv[i] = inv[i+1]*(i+1)%mod;
}
int n,m;
struct BIT{
int tree[500005];
void clear(){for(int i = 0;i<=n+2;i++)tree[i] = 0;}
void upd(int pos,int add){
for(int i = pos;i<=n+2;i+=lowbit(i))tree[i]+=add;
}
int query(int pos){
int res = 0;
for(int i = pos;i>0;i-=lowbit(i))res+=tree[i];
return res;
}
}T;
int fa[1000005],a[1000005];
vector<int>p[1000005];
bool OK = 1;
int rt;
int ls[1000005],rs[1000005],rnd[1000005],sz[1000005];
int lst[1000005],nxt[1000005];
int val[1000005];//求出的一组解
void push_up(int k){sz[k] = sz[ls[k]]+sz[rs[k]]+1;}
void split_(int now,int k,int& x,int& y){
if(!now){x = y = 0;return;}
if(k>=sz[ls[now]]+1){
x = now;
split_(rs[now],k-sz[ls[now]]-1,rs[now],y);
}else{
y = now;
split_(ls[now],k,x,ls[now]);
}
push_up(now);
}
int merge(int x,int y){
if(!x or !y)return x|y;
if(rnd[x]<rnd[y]){
rs[x] = merge(rs[x],y);
push_up(x);
return x;
}else{
ls[y] = merge(x,ls[y]);
push_up(y);
return y;
}
}
int find_first(int k){while(ls[k])k = ls[k];return k;}
int find_last(int k){while(rs[k])k = rs[k];return k;}
void add(int id){
//插入到第 aid 个位置后
int x,y;
split_(rt,a[id]+1,x,y);//前面有个 0,多加 1
lst[id] = find_last(x),nxt[id] = find_first(y);
assert(x and y);
rt = merge(merge(x,id),y);
}
void del(int id){
int x,y,z;
split_(rt,a[id]+1,x,y);
split_(y,1,y,z);
rt = merge(x,z);
}
int nw = 0;
int dfn[500005],ssz[500005],b[500005];
void dfs(int now,int d){
if(a[now]>d){OK = 0;return;}
dfn[now] = ++nw;
ssz[now] = 1;
b[nw] = now;
add(now);
for(auto x:p[now])dfs(x,d+1),ssz[now]+=ssz[x];
del(now);
}
int in[500005];
vector<int>pp[500005];
void add(int x,int y){pp[x].push_back(y);in[y]++;}
void topo(){
queue<int>q;
for(int i = 1;i<=n+2;i++)if(!in[i])q.push(i);
int cc = 0;
while(!q.empty()){
int now = q.front();q.pop();
val[now] = ++cc;
for(auto x:pp[now])if(--in[x] == 0)q.push(x);
}
assert(cc == n+2);
//求出一组解
}
int pro = 1,tot = 0;
int l1[1500005],r1[1500005],l2[1500005],r2[1500005],ans[1500005];
bool f[1500005];
vector<pair<int,int> >ll[500005];
void add(int L1,int R1,int L2,int R2,int F){
++tot;
l1[tot] = L1,r1[tot] = R1,l2[tot] = L2,r2[tot] = R2,f[tot] = F;
}
void work(){
for(int i = 1;i<=tot;i++)ll[l1[i]-1].push_back({i,-1}),ll[r1[i]].push_back({i,1});
int sum = 0;
for(int i = 1;i<=n;i++){
//加入 bi
T.upd(val[b[i]],1);
for(auto x:ll[i]){
int id = x.first,f = x.second;
ans[id] += f*(T.query(r2[id])-T.query(l2[id]-1));
}
}
}
void solve(){
for(int i = 0;i<=tot;i++)ans[i] = 0;
T.clear();nw = 0;tot = 0;
for(int i = 0;i<=n+2;i++)in[i] = 0,p[i].clear(),pp[i].clear(),ll[i].clear();
cin >> n;
for(int i = 2;i<=n;i++)cin>>fa[i];
for(int i = 1;i<=n;i++)cin>>a[i];
for(int i = 2;i<=n;i++)p[fa[i]].push_back(i);
for(int i = 1;i<=n+2;i++)ls[i] = rs[i] = 0,sz[i] = 1,rnd[i] = rand();
rt = merge(n+2,n+1);
OK = 1;pro = jie[n];
dfs(1,0);
for(int i = 1;i<=n;i++)add(lst[i],i),add(i,nxt[i]);
topo();
if(!OK){cout << 0 << '\n';return;}
for(int i = 1;i<=n;i++){
add(dfn[i],dfn[i]+ssz[i]-1,val[lst[i]],val[nxt[i]],0);
add(dfn[i]+1,dfn[i]+ssz[i]-1,val[lst[i]],val[i],1);
add(dfn[i]+1,dfn[i]+ssz[i]-1,val[i],val[nxt[i]],1);
// i 本身不计入
}
work();
for(int i = 1;i<=tot;i++){
if(f[i])pro = pro*jie[ans[i]]%mod;
else pro = pro*inv[ans[i]]%mod;
}
cout << pro << '\n';
}
signed main(){
srand(time(0));
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
init();
int t;cin >> t;
while(t--)solve();
return 0;
}