题解 P9103 [PA2020] Bardzo skomplikowany test

· · 题解

考虑如果第一棵树的根是 a,第二棵树的根是 b。不妨假设 a<b,另一个方向对称。

此时我们为了把 b 移动到根,我们显然需要将 a+1\sim b-1 的所有点移动到 a,b 之间,形成一条 a-(a+1)-(a+2)-\cdots-(b-1)-b 的链,然后将 b 转到根。

此时树根的两侧是独立子问题,递归做即可。

为什么这么做是最优的?

考虑:

  1. 如果考虑树的父亲序列,则题目中的操作一定只会同时改变中序遍历上相邻的两个点的父亲。
  2. 复位根必须要求形成一条 a-(a+1)-(a+2)-\cdots-(b-1)-b 的链,从而一切会改变 a,b 以及中间的任何一个点的,且不属于构建这条链的过程的结构改变都将被覆盖,没有必要。
  3. 其余改变与构建链的过程独立,可以交换到构建链后面。

然后就是考虑如何构建这条链。

这个唯一性很强:从 a 开始递归往右走直到走到一个 \geq k 的点,设走到 a-c_1-c_2-\cdots-c_k。然后将 c_1 的左子树修改成一条左链,然后往右转直到左子树消失。然后对 c_2,\cdots,c_k 同样处理。

然后再考虑如何把一个子树修改成左链/右链。(右链是因为需要对称)

这个更直接,把左子树修改成左链,右子树修改成右链,然后依照需求往左/往右转。

这个修改次数就可以直接 dp 了。设 f_{u,0/1} 表示修改成左链/右链的方案数,s_uu 的子树的大小。转移:

其中 l,ru 的左右孩子。

考虑算答案。

考虑我们复位根之后剩下的是包含根的一条链,两头接上 原本就在 a 中存在 的子树。

所以我们就可以将 a 的形态记作一棵子树上面接一条左链/右链,且这个链上的点是一个连续区间。

然后注意到上面那个“从 a 开始递归往右走直到走到一个 \geq k 的点”的过程不会多次搜到同一个点,于是直接暴力做,复杂度就是 O(n),就做完了。

const int N = 500005;
const long long mod = 1000000007;

int n, als[N], ars[N], bls[N], brs[N], alt[N], blt[N], art[N], brt[N], fa[N], fb[N], rta, rtb;
long long dpa[N][2], dpb[N][2];

inline void Read() {
    n = qread();
    for (int i = 1;i <= n;i++) {
        fa[i] = qread();
        if (fa[i] != -1 && fa[i] < i) ars[fa[i]] = i;
        else if (fa[i] != -1 && fa[i] > i) als[fa[i]] = i;
        else rta = i;
    }
    for (int i = 1;i <= n;i++) {
        fb[i] = qread();
        if (fb[i] != -1 && fb[i] < i) brs[fb[i]] = i;
        else if (fb[i] != -1 && fb[i] > i) bls[fb[i]] = i;
        else rtb = i;
    }
}

inline void DfsA(int u) {
    alt[u] = art[u] = u;
    dpa[u][0] = dpa[u][1] = 0;
    if (als[u]) {
        DfsA(als[u]);
        alt[u] = min(alt[u], alt[als[u]]);
        dpa[u][0] += dpa[als[u]][0];
        dpa[u][1] += art[als[u]] - alt[als[u]] + 1 + dpa[als[u]][0];
    }
    if (ars[u]) {
        DfsA(ars[u]);
        art[u] = max(art[u], art[ars[u]]);
        dpa[u][0] += art[ars[u]] - alt[ars[u]] + 1 + dpa[ars[u]][1];
        dpa[u][1] += dpa[ars[u]][1];
    }
}

inline void DfsB(int u) {
    blt[u] = brt[u] = u;
    dpb[u][0] = dpb[u][1] = 0;
    if (bls[u]) {
        DfsB(bls[u]);
        blt[u] = min(blt[u], blt[bls[u]]);
        dpb[u][0] += dpb[bls[u]][0];
        dpb[u][1] += brt[bls[u]] - blt[bls[u]] + 1 + dpb[bls[u]][0];
    }
    if (brs[u]) {
        DfsB(brs[u]);
        brt[u] = max(brt[u], brt[brs[u]]);
        dpb[u][0] += brt[brs[u]] - blt[brs[u]] + 1 + dpb[brs[u]][1];
        dpb[u][1] += dpb[brs[u]][1];
    }
    dpb[u][0] %= mod; dpb[u][1] %= mod;
}

inline long long Solve(int nda, int sgl, int sgr, int flg, int ndb) {
    if (sgl > sgr) {
        sgl = sgr = -1;
        flg = 0;
    }
    long long res = -1;
    if (flg == 0) {
        if (nda == ndb) {
            long long ans = 0;
            if (als[nda]) ans += Solve(als[nda], -1, -1, 0, bls[ndb]);
            if (ars[nda]) ans += Solve(ars[nda], -1, -1, 0, brs[ndb]);
            res = ans % mod;
        } else if (nda < ndb) {
            int u = nda;
            long long ans = 0;
            while (u < ndb) {
                u = ars[u];
                if (als[u]) ans += dpa[als[u]][0] + art[als[u]] - alt[als[u]] + 1;
            }
            int l = nda, r = u;
            ans += ndb - nda;
            if (als[nda]) ans += Solve(als[nda], l, ndb - 1, 2, bls[ndb]);
            else ans += dpb[bls[ndb]][0];
            if (ars[u]) ans += Solve(ars[u], ndb + 1, r, 1, brs[ndb]);
            else ans += dpb[brs[ndb]][1];
            res = ans % mod;
        } else if (nda > ndb) {
            int u = nda;
            long long ans = 0;
            while (u > ndb) {
                u = als[u];
                if (ars[u]) ans += dpa[ars[u]][1] + art[ars[u]] - alt[ars[u]] + 1;
            }
            int l = u, r = nda;
            ans += nda - ndb;
            if (als[u]) ans += Solve(als[u], l, ndb - 1, 2, bls[ndb]);
            else ans += dpb[bls[ndb]][0];
            if (ars[nda]) ans += Solve(ars[nda], ndb + 1, r, 1, brs[ndb]);
            else ans += dpb[brs[ndb]][1];
            res = ans % mod;
        }
    } else if (flg == 1) {
        if (sgl <= ndb && ndb <= sgr) res = (ndb - sgl + dpb[bls[ndb]][0] + Solve(nda, ndb + 1, sgr, 1, brs[ndb])) % mod;
        else {
            int u = nda;
            long long ans = 0;
            while (u < ndb) {
                if (als[u]) ans += dpa[als[u]][0] + art[als[u]] - alt[als[u]] + 1;
                u = ars[u];
            }
            if (als[u]) ans += dpa[als[u]][0] + art[als[u]] - alt[als[u]] + 1;
            u = ars[u];
            if (u == 0) res = (ans + dpb[ndb][1]) % mod;
            else res = (ans + Solve(u, sgl, fa[u], 1, ndb)) % mod;
        }
    } else {
        if (sgl <= ndb && ndb <= sgr) res = (sgr - ndb + dpb[brs[ndb]][1] + Solve(nda, sgl, ndb - 1, 2, bls[ndb])) % mod;
        else {
            int u = nda;
            long long ans = 0;
            while (u > ndb) {
                if (ars[u]) ans += dpa[ars[u]][1] + art[ars[u]] - alt[ars[u]] + 1;
                u = als[u];
            }
            if (ars[u]) ans += dpa[ars[u]][1] + art[ars[u]] - alt[ars[u]] + 1;
            u = als[u];
            if (u == 0) res = (ans + dpb[ndb][0]) % mod;
            else res = (ans + Solve(u, fa[u], sgr, 2, ndb)) % mod;
        }
    }
    return (res % mod + mod) % mod;
}