题解 P10226【[COCI 2023/2024 #3] Restorani】

· · 题解

神秘题。

首先如果把餐厅记作 +1,将甜品店记作 -1,那么如果一条边的某一侧的数字和为 x(x>0),那么这条边至少需要经过 2x 次。

x=0,如果这条边的两侧都有关键点(1 号点或者餐厅或者甜品店),那么这条边仍然至少需要经过 2 次。

否则,这条边可以不被经过。

但是怎么证明这样做是对的呢?我们考虑一个神秘的构造方法:如果一个点的子树里 +- 多,那么可以把该子树内的 +- 排列成若干个形如 +-+\cdots+ 的段(称为 + 段),每段后留一个空位。

类似的,如果一个点的子树里 -+ 多,那么可以把该子树内的 +- 排列成若干个形如 -+-\cdots- 的段(称为 - 段),每段前留一个空位。

如果 +- 一样多,那么可以排成单独的一个形如 +-+-\cdots+- 的段(称为 0 段)。

对于子树信息的合并,我们只需要将 + 段和 - 段交错合并即可,而 0 段可以和任意一个 +,-,0 段合并,直至合并成若干 + 段或若干 - 段或一个 0 段。

容易证明,使用这种方法可以使得每条边被经过的次数都取得最小值。

实现的时候,可以用链表套链表,外层链表维护一个子树包含哪些段,内层链表维护段内的 + 点和 - 点的编号。

时间复杂度为 \mathcal{O}(n)

代码如下:

#include <bits/stdc++.h>
using namespace std;
const int _ = 3e5 + 10;
int n, m, arr[_], brr[_], pos[_], neg[_], cntpos[_], cntneg[_], pnex[_], nnex[_], e, hd[_], nx[600010], to[600010], lef[_], rig[_], ans[_], bns[_];
int pcnt, ptop, pbin[_], pl[_], pr[_], pt[_];
int ncnt, ntop, nbin[_], nl[_], nr[_], nt[_];
long long len;
inline int pget(void) {
    if (ptop) {
        return pbin[ptop--];
    } else {
        return ++pcnt;
    }
}
inline void pdel(int x) {
    pbin[++ptop] = x;
}
inline int nget(void) {
    if (ntop) {
        return nbin[ntop--];
    } else {
        return ++ncnt;
    }
}
inline void ndel(int x) {
    nbin[++ntop] = x;
}
inline void pmerge(int& L, int& R, int l, int r) {
    if (l == 0 && r == 0) {
    } else if (L == 0 && R == 0) {
        L = l;
        R = r;
    } else {
        pt[R] = l;
        R = r;
    }
}
inline void nmerge(int& L, int& R, int l, int r) {
    if (l == 0 && r == 0) {
    } else if (L == 0 && R == 0) {
        L = l;
        R = r;
    } else {
        nt[R] = l;
        R = r;
    }
}
inline void mmerge(int& L, int& R, int l, int r) {
    if (l == 0 && r == 0) {
    } else if (L == 0 && R == 0) {
        L = l;
        R = r;
    } else {
        nnex[R] = l;
        R = r;
    }
}
inline void three_way_merge(int& L, int& R, int lp, int rp, int ln, int rn, int lm, int rm) {
    if (lp == 0 && rp == 0 && ln == 0 && rn == 0) {
        L = lm;
        R = rm;
    } else if (lp == 0 && rp == 0) {
        if (!(lm == 0 && rm == 0)) {
            nnex[nr[rn]] = lm;
            nr[rn] = rm;
        }
        L = ln;
        R = rn;
    } else if (ln == 0 && rn == 0) {
        if (!(lm == 0 && rm == 0)) {
            nnex[rm] = pl[rp];
            pl[rp] = lm;
        }
        L = lp;
        R = rp;
    } else {
        while (lp != -1 && ln != -1) {
            pnex[pr[lp]] = nl[ln];
            if (lm == 0 && rm == 0) {
                lm = pl[lp];
                rm = nr[ln];
            } else {
                nnex[rm] = pl[lp];
                rm = nr[ln];
            }
            pdel(lp);
            ndel(ln);
            lp = pt[lp];
            ln = nt[ln];
        }
        if (lp == -1 && ln == -1) {
            L = lm;
            R = rm;
        } else if (lp == -1) {
            nnex[nr[rn]] = lm;
            nr[rn] = rm;
            L = ln;
            R = rn;
        } else {
            nnex[rm] = pl[rp];
            pl[rp] = lm;
            L = lp;
            R = rp;
        }
    }
}
inline void add(int u, int v) {
    e++;
    nx[e] = hd[u];
    to[e] = v;
    hd[u] = e;
}
void dfs(int x, int f) {
    if (pos[x]) cntpos[x]++;
    if (neg[x]) cntneg[x]++;
    for (int i = hd[x]; i; i = nx[i]) {
        int y = to[i];
        if (y == f) continue;
        dfs(y, x);
        cntpos[x] += cntpos[y];
        cntneg[x] += cntneg[y];
        if (cntpos[y] > cntneg[y]) {
            len += cntpos[y] - cntneg[y];
        } else if (cntneg[y] > cntpos[y]) {
            len += cntneg[y] - cntpos[y];
        } else if (cntpos[y]) {
            len++;
        }
    }
}
void solve(int x, int f) {
    int lp = 0, rp = 0;
    int ln = 0, rn = 0;
    int lm = 0, rm = 0;
    if (pos[x]) {
        int a = pget();
        lp = rp = a;
        pl[a] = pos[x];
        pr[a] = pos[x];
        pt[a] = -1;
    }
    if (neg[x]) {
        int a = nget();
        ln = rn = a;
        nl[a] = neg[x];
        nr[a] = neg[x];
        nt[a] = -1;
    }
    for (int i = hd[x]; i; i = nx[i]) {
        int y = to[i];
        if (y == f) continue;
        solve(y, x);
        if (cntpos[y] > cntneg[y]) {
            pmerge(lp, rp, lef[y], rig[y]);
        } else if (cntpos[y] < cntneg[y]) {
            nmerge(ln, rn, lef[y], rig[y]);
        } else {
            mmerge(lm, rm, lef[y], rig[y]);
        }
    }
    three_way_merge(lef[x], rig[x], lp, rp, ln, rn, lm, rm);
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        cin >> arr[i];
        pos[arr[i]] = i;
    }
    for (int i = 1; i <= m; i++) {
        cin >> brr[i];
        neg[brr[i]] = i;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    dfs(1, 0);
    (len <<= 1LL);
    solve(1, 0);
    ans[1] = lef[1];
    bns[1] = pnex[ans[1]];
    for (int i = 2; i <= m; i++) {
        ans[i] = nnex[bns[i-1]];
        bns[i] = pnex[ans[i]];
    }
    cout << len << endl;
    for (int i = 1; i <= m; i++) {
        cout << ans[i] << ' ' << bns[i];
        if (i != m) cout << ' ';
    }
    cout << endl;
    return 0;
}