题解:CF2164F2 Chain Prefix Rank (Hard Version)

· · 题解

详细揭秘不会求 dag 拓扑序应该如何解决这道题。

小引理:对于一个排列 p,设 a_i 表示 \sum_{j<i} [p_j<p_i] 的值,对于一组 0\leq i \leq a_i 的序列 \{a\},存在唯一的排列 p 与之对应。

于是我们相当于已经知道树上任意两个为祖先-后代的关系的点的权值的大小关系,现在要求方案数。

可以像这样做:对于一个点 x,维护出 1x 路径上所有点的相对顺序,在其中找到 x 的前驱,后继,设为 lst_x,nxt_x,连边 lst_x\to x,x\to nxt_x

连出的 dag 的拓扑序个数即为答案。

但是问题在于我们并不会求 dag 的拓扑序,所以大概想到,要在树的结构上考虑。

我们考察这样计算:对于一个点 x,计算 x 的子树内所有点的值与 fa_x1 的这条链上的所有点的值的大小关系满足限制的方案数。

然后我们考虑如何从只考虑与 1\to fa_x 链之间大小关系的方案数转移到考虑了 1\to x 链之间之间大小关系的方案数。

我们考察将 1\to fa_x 这条链上的点按大小顺序写下来,设有 l 个点,那么相对大小关系被划分成了 l+1 个区域,此时我们认为每个区域内的点都是等价的,也就是说如果确定了子树内某个点在哪个区域,就可以在这个区域内随便选一个。

如图,链上有四个点,分成了五个等价类。

然后我们加入 x,中间有个等价类被分成了两半。

考虑如何计算增加了限制后的方案数。

实际上就是,原本这一个等价类里面的数可以随便排,然后被切成两个等价类,现在这两个等价类内部可以随便排。

设插入一个 x,切分出左边等价类是大小是 c_1,右边等价类大小是 c_2,此时对答案的贡献就是 \dfrac{c_1!c_2!}{(c_1+c_2+1)!}(答案乘上了这个数)

意思是原本算的方案里,这个等价类可以随便排,现在拆成两个了,先除一个随便排,再乘上两个随便排。

更形式化的说,设 a_i 表示 i 子树内大小在 lst_i,i 两者之间的点的个数,b_i 表示 i 子树内大小在 i,nxt_i 两者之间的点的个数,答案即为 n!\prod_i \dfrac{a_i!b_i!}{(a_i+b_i+1)!}

这里还有另一种理解方法,我们知道目前有 l+1 个等价类,也知道 x 的所有儿子子树里每个等价类应该有多少个点,我们可以对于每一个等价类点,算出其分配到儿子子树内的方案(就是一个多重组合数,总共有 M 个,每个子树分 c_i 个),但是这样做状态数就是 O(n^2) 的,我 vp 的时候写的就是这个做法,然后把关于每个点的贡献写出来,发现全部可以抵消,也能得到同样的式子。

于是我们只要计算出这个和子树相关的信息即可。

但是我们发现题目的给出的信息还是有点抽象,我想要维护出大小关系只能用平衡树维护一条链,因此很难计算子树内维护两个点所对应的数之间的点的个数。

注意到,我们可以求出一组满足限制条件的初始解,这样我们就能快速的比较大小关系,而且也不会比较错。

这部分随便跑一组拓扑序就行,然后后面的部分就只需要做一个二维数点,复杂度 O(n\log n)

/*
目前以 u 为根的子树已经满足了 fau 到 1 链上的点的大小关系,有一个方案数
插入一个 u,限制会变紧
设 u 前驱是 x,后继是 y
原本位于 [x,y] 这一段里面的数,现在有一部分要去 [x,u] 有一部分要去 [u,y]
除掉原本 [x,y] 这一段的贡献,乘上 [x,u],[u,y] 这两端的贡献 
*/
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define lowbit(x) (x&(-x))
const int mod = 998244353;
const int M = 1000000;
int qp(int p,int q){
    int ans = 1,pro = p;
    while(q){
        if(q&1)ans = ans*pro%mod;
        pro = pro*pro%mod;q>>=1;
    }
    return ans;
}
int jie[1000005],inv[1000005];
void init(){
    jie[0] = 1;for(int i = 1;i<=M;i++)jie[i] = jie[i-1]*i%mod;
    inv[M] = qp(jie[M],mod-2);
    for(int i = M-1;i>=0;i--)inv[i] = inv[i+1]*(i+1)%mod;
}
int n,m;
struct BIT{
    int tree[500005];
    void clear(){for(int i = 0;i<=n+2;i++)tree[i] = 0;}
    void upd(int pos,int add){
        for(int i = pos;i<=n+2;i+=lowbit(i))tree[i]+=add;
    }
    int query(int pos){
        int res = 0;
        for(int i = pos;i>0;i-=lowbit(i))res+=tree[i];
        return res;
    }
}T;
int fa[1000005],a[1000005];
vector<int>p[1000005];
bool OK = 1;
int rt;
int ls[1000005],rs[1000005],rnd[1000005],sz[1000005];
int lst[1000005],nxt[1000005];
int val[1000005];//求出的一组解 
void push_up(int k){sz[k] = sz[ls[k]]+sz[rs[k]]+1;}
void split_(int now,int k,int& x,int& y){
    if(!now){x = y = 0;return;}
    if(k>=sz[ls[now]]+1){
        x = now;
        split_(rs[now],k-sz[ls[now]]-1,rs[now],y);
    }else{
        y = now;
        split_(ls[now],k,x,ls[now]);
    }
    push_up(now);
}
int merge(int x,int y){
    if(!x or !y)return x|y;
    if(rnd[x]<rnd[y]){
        rs[x] = merge(rs[x],y);
        push_up(x);
        return x;
    }else{
        ls[y] = merge(x,ls[y]);
        push_up(y);
        return y;
    }
} 
int find_first(int k){while(ls[k])k = ls[k];return k;}
int find_last(int k){while(rs[k])k = rs[k];return k;} 
void add(int id){
    //插入到第 aid 个位置后
    int x,y;
    split_(rt,a[id]+1,x,y);//前面有个 0,多加 1  
    lst[id] = find_last(x),nxt[id] = find_first(y);
    assert(x and y);
    rt = merge(merge(x,id),y);
}
void del(int id){
    int x,y,z;
    split_(rt,a[id]+1,x,y);
    split_(y,1,y,z);
    rt = merge(x,z);
}
int nw = 0;
int dfn[500005],ssz[500005],b[500005];
void dfs(int now,int d){
    if(a[now]>d){OK = 0;return;}
    dfn[now] = ++nw;
    ssz[now] = 1;
    b[nw] = now;
    add(now);
    for(auto x:p[now])dfs(x,d+1),ssz[now]+=ssz[x];
    del(now);
}
int in[500005];
vector<int>pp[500005];
void add(int x,int y){pp[x].push_back(y);in[y]++;}
void topo(){
    queue<int>q;
    for(int i = 1;i<=n+2;i++)if(!in[i])q.push(i);
    int cc = 0;
    while(!q.empty()){
        int now = q.front();q.pop();
        val[now] = ++cc;
        for(auto x:pp[now])if(--in[x] == 0)q.push(x);
    }
    assert(cc == n+2);
    //求出一组解 
}
int pro = 1,tot = 0;
int l1[1500005],r1[1500005],l2[1500005],r2[1500005],ans[1500005];
bool f[1500005];
vector<pair<int,int> >ll[500005];
void add(int L1,int R1,int L2,int R2,int F){
    ++tot;
    l1[tot] = L1,r1[tot] = R1,l2[tot] = L2,r2[tot] = R2,f[tot] = F; 
}
void work(){
    for(int i = 1;i<=tot;i++)ll[l1[i]-1].push_back({i,-1}),ll[r1[i]].push_back({i,1});
    int sum = 0;
    for(int i = 1;i<=n;i++){
        //加入 bi
        T.upd(val[b[i]],1); 
        for(auto x:ll[i]){
            int id = x.first,f = x.second;
            ans[id] += f*(T.query(r2[id])-T.query(l2[id]-1));
        }
    }
}
void solve(){
    for(int i = 0;i<=tot;i++)ans[i] = 0;
    T.clear();nw = 0;tot = 0;
    for(int i = 0;i<=n+2;i++)in[i] = 0,p[i].clear(),pp[i].clear(),ll[i].clear();
    cin >> n;
    for(int i = 2;i<=n;i++)cin>>fa[i];
    for(int i = 1;i<=n;i++)cin>>a[i];
    for(int i = 2;i<=n;i++)p[fa[i]].push_back(i);
    for(int i = 1;i<=n+2;i++)ls[i] = rs[i] = 0,sz[i] = 1,rnd[i] = rand();
    rt = merge(n+2,n+1);
    OK = 1;pro = jie[n];
    dfs(1,0);   
    for(int i = 1;i<=n;i++)add(lst[i],i),add(i,nxt[i]);
    topo();
    if(!OK){cout << 0 << '\n';return;}
    for(int i = 1;i<=n;i++){
        add(dfn[i],dfn[i]+ssz[i]-1,val[lst[i]],val[nxt[i]],0);
        add(dfn[i]+1,dfn[i]+ssz[i]-1,val[lst[i]],val[i],1);
        add(dfn[i]+1,dfn[i]+ssz[i]-1,val[i],val[nxt[i]],1);
        // i 本身不计入 
    }   
    work();
    for(int i = 1;i<=tot;i++){
        if(f[i])pro = pro*jie[ans[i]]%mod;
        else pro = pro*inv[ans[i]]%mod;
    }
    cout << pro << '\n';
}
signed main(){
    srand(time(0));
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    init();
    int t;cin >> t;
    while(t--)solve();
    return 0;
}