题解:P13804 [SWERC 2023] In-order

· · 题解

首先要说的是本人博客膜拜评论会被删,请点评论区上方“点赞”按钮作为替代,感谢各位支持。

题意

某人有一棵大小为 n 的二叉树,但他不告诉你树,而告诉你这棵树的前序遍历、后序遍历、以及中序遍历的一个片段,问树有几种可能?

已知的信息

我们先试图找出先序遍历和后序遍历能确定什么。我们发现:

  1. 对于每个子树,所在段第一个为根。
  2. 所在段先序遍历的第二个元素和后序遍历的倒数第二个元素为根的儿子信息。相同则根有一个儿子,需要统计儿子为左儿子或右儿子的情况,不同则为两个儿子,先序遍历的第二个元素为左儿子,后序遍历的倒数第二个元素为右儿子。

我们通过先序遍历和后续遍历已得出了海量的信息,这意味着题目并非困难,而我们要做的是有一个儿子的情况下儿子为左儿子或右儿子的影响。

我们发现:儿子为左儿子时,节点在子树的中序遍历的最后一项出现,反之,右儿子意味着节点子树在中序遍历最前方,并且会使子树上其他节点在中序遍历的出现项数加一。

这就意味着节点在第几项与如图的红点相关:

也就是说,像平衡树查找一样,当走向右儿子时将根与左子树大小加上,最后加上这个点和它的左子树大小,就是一个点在中序遍历中的位置。

接下来我们要进行分类讨论。

分类讨论

接下来用到的变量如下:

分类:

对于情况三去确定儿子信息的方式如下:

(如图,红边确定为左儿子,蓝边确定为右儿子,若下一个点为祖先对应 x\to z,这个点为下一个点祖先对应 z \to y,两点间必为祖先和儿子关系)

我们依此更新各参数,注意到暴力跳链至多跳 \operatorname{O}(n) 次(可以把涉及的点提出来形成树,发现每条边都至多在跳链中算两次),我们可以暴力跳链更新确定的情况,然后重新深搜全树重算 cnt,cntv_i,mrnk_i。然后用类似于第二种情况的算法将开头的项数定在 pl

若不想更新 cntv_i,可以发现我们能在跳链的过程中算出所有中序遍历涉及的点的最近公共祖先 y,而修改后 cntv_{plv} = cntv_{y},这样我们就算出了答案。

代码

#include<bits/stdc++.h>
using namespace std;
int n,q[500009],dep[500009],fa[500009],h[500009],fr[500009],lc[500009],rc[500009],z[500009],cnt,cntv[500009],sz[500009],mrnk[500009];
int plq[500009],plh[500009],cst;
bool fix[500009];
void build_tree(int l,int r,int l1,int r1){
    fr[q[l]] = r;
    sz[q[l]] = r - l + 1;
    if(l == r)
        return;
    if(q[l + 1] == h[r1 - 1]){
        cnt ++;
        lc[q[l]] = q[l + 1];
        dep[q[l + 1]] = dep[q[l]] + 1;
        mrnk[q[l + 1]] = mrnk[q[l]];
        cntv[q[l + 1]] = cntv[q[l]] + 1;
        fa[q[l + 1]] = q[l]; 
        build_tree(l + 1,r,l1,r1 - 1);
    }
    else{
        lc[q[l]] = q[l + 1];
        rc[h[r1]] = h[r1 - 1];
        dep[h[r1 - 1]] = dep[q[l + 1]] = dep[q[l]] + 1;
        fa[q[l + 1]] = fa[h[r1 - 1]] = q[l];
        mrnk[q[l + 1]] = mrnk[q[l]];
        mrnk[q[l]] = mrnk[q[l]] + plq[h[r1 - 1]] - l - 1;
        mrnk[h[r1 - 1]] = mrnk[q[l]] + 1;
        cntv[q[l + 1]] = cntv[h[r1 - 1]] = cntv[q[l]];
        build_tree(l + 1,plq[h[r1 - 1]] - 1,l1,plh[q[l + 1]]);
        build_tree(plq[h[r1 - 1]],r,plh[q[l + 1]] + 1,r1 - 1);
    }
}
void show(int x){
    if(x == 0)
        return;
    if(lc[x] != 0){
        mrnk[lc[x]] = mrnk[x],show(lc[x]);
        if(fix[x] || rc[x] != 0)
            mrnk[x] += sz[lc[x]];
    }
//  printf("%d %d\n",x,mrnk[x]);
    if(rc[x])
        mrnk[rc[x]] = mrnk[x] + 1,show(rc[x]);
}
int ox[500009],s[500009],ni[500009],inv[500009];
const int mod = 999999937;
char gc(){
    static char buf[1919810],*p1 = buf,*p2 = buf;
    return (p1 == p2) && (p2 = (p1 = buf) + fread(buf,1,1919809,stdin),p1 == p2) ? EOF : *p1 ++;
}
void fin(int &x){
    x = 0;
    char c = gc();
    while(c > '9' || c < '0')
        c = gc();
    while(c >= '0' && c <= '9')
        x = (x << 1) + (x << 3) + (c ^ 48),c = gc();
}
int main(){
    //freopen("tree7.in","r",stdin);
    fin(n);
    for(int i = 1; i <= n; i ++){
        fin(q[i]);
        plq[q[i]] = i;
    }
    for(int i = 1; i <= n; i ++){
        fin(h[i]);
        plh[h[i]] = i;
    }
    build_tree(1,n,1,n);
//  show(q[1]);
    int pl = 0;
    for(int i = 1; i <= n; i ++){
        fin(z[i]);
        cst += z[i] > 0;
        if(z[i] > 0 && pl == 0)
            pl = i;
    }
    ox[0] = s[0] = ni[0] = s[1] = ni[1] = inv[1] = 1;
    ox[1] = 2;
    for(int i = 2; i <= n; i ++){
        ox[i] = (ox[i - 1] << 1) % mod;
        s[i] = 1ll * s[i - 1] * i % mod;
        inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
        ni[i] = 1ll * inv[i] * ni[i - 1] % mod;
    } 
    //printf("%d %d\n",cst,z[pl]);
    if(cst == 0){
        printf("%d\n",ox[cnt]);
    }
    else if(cst == 1){
        int o = 0;
        //printf("%d %d %d %d %d\n",pl,mrnk[z[pl]],cnt,cntv[z[pl]],(lc[z[pl]] && !rc[z[pl]]));
        if(pl - mrnk[z[pl]] - 1 <= cntv[z[pl]])
            o = 1ll * ox[cnt - cntv[z[pl]] - (lc[z[pl]] && !rc[z[pl]])] * s[cntv[z[pl]]] % mod * ni[pl - mrnk[z[pl]] - 1] % mod * ni[cntv[z[pl]] - (pl - mrnk[z[pl]] - 1)] % mod;
        if(lc[z[pl]] > 0 && rc[z[pl]] == 0 && pl - mrnk[z[pl]] - 1 - sz[lc[z[pl]]] >= 0)
            o = (o + 1ll * ox[cnt - cntv[z[pl]] - (lc[z[pl]] && !rc[z[pl]])] * s[cntv[z[pl]]] % mod * ni[pl - mrnk[z[pl]] - 1 - sz[lc[z[pl]]]] % mod * ni[cntv[z[pl]] - (pl - mrnk[z[pl]] - 1) + sz[lc[z[pl]]]] % mod) % mod;
        printf("%d\n",o);
    }
    else{
        int nw = z[pl];
        for(int i = pl; z[i + 1] > 0; i ++){
            int x = z[i],y = z[i + 1],a = z[i],b = z[i + 1];
            while(x != y){
                if(dep[x] < dep[y]){
                    y = fa[y];
                }
                else{
                    x = fa[x];
                }
            }
            if(dep[nw] > dep[x])
                nw = x;
            if(z[i] != x && !fix[z[i]] && lc[z[i]] != 0 && rc[z[i]] == 0){
                cnt --;
                fix[z[i]] = true;
            }
            if(z[i + 1] != x && !fix[z[i + 1]] && lc[z[i + 1]] != 0 && rc[z[i + 1]] == 0){
                fix[z[i + 1]] = true;
                cnt --;
                swap(lc[z[i + 1]],rc[z[i + 1]]);
            }
            while(a != x){
                if(fa[a] != x){
                    if(!fix[fa[a]] && rc[fa[a]] == 0){
                        cnt --;
                        swap(lc[fa[a]],rc[fa[a]]);
                        fix[fa[a]] = true;
                    }
                }
                else{
                    if(!fix[fa[a]] && rc[fa[a]] == 0){
                        cnt --;
                        fix[fa[a]] = true;
                    }
                }
                a = fa[a];
            }
            while(b != x){
                if(fa[b] == x){
                    if(!fix[fa[b]] && rc[fa[b]] == 0){
                        cnt --;
                        swap(lc[fa[b]],rc[fa[b]]);
                        fix[fa[b]] = true;
                    }
                }
                else{
                    if(!fix[fa[b]] && rc[fa[b]] == 0){
                        cnt --;
                        fix[fa[b]] = true;
                    }
                }
                b = fa[b];
            }
        }
        mrnk[q[1]] = 0;
        show(q[1]);
    //  printf("%d %d %d %d %d\n",cnt,nw,cntv[nw],mrnk[z[pl]],pl);
        int o = 1ll * ox[cnt - cntv[nw]] * s[cntv[nw]] % mod * ni[pl - mrnk[z[pl]] - 1] % mod * ni[cntv[nw] - (pl - mrnk[z[pl]] - 1)] % mod;
        printf("%d\n",o);
    }
}