题解 P5803 【[SEERC2019] Tree Permutations】

· · 题解

首先对 a 排序。

证明:已知 p_i < i,但此时构造出的 p_{i + 1} > i,则不合法。

证明:此时任何从 \leq i 的点到 > i 的点的路径都要经过 i

证明:由上一条以及 p_i < i 可得。

证明:我们可以构造如下一棵树:

而我们构造出的树同样满足权值最大的性质,于是我们维护一棵支持单点删除、求前 k 大的线段树即可。时间复杂度为 O(n \log n)

代码:

#include <iostream>
#include <algorithm>
#include <functional>

using namespace std;

typedef long long ll;

typedef struct {
    int l;
    int r;
    int cnt;
    ll sum;
} Node;

int a[200007], b[200007], diff[100007];
bool mark[100007], vis[100007];
Node tree[800007];

inline void update(int x){
    int ls = x * 2, rs = x * 2 + 1;
    tree[x].cnt = tree[ls].cnt + tree[rs].cnt;
    tree[x].sum = tree[ls].sum + tree[rs].sum;
}

void build(int x, int l, int r){
    tree[x].l = l;
    tree[x].r = r;
    if (l == r){
        tree[x].cnt = 1;
        tree[x].sum = b[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(x * 2, l, mid);
    build(x * 2 + 1, mid + 1, r);
    update(x);
}

void erase(int x, int pos){
    if (tree[x].l == tree[x].r){
        tree[x].cnt = tree[x].sum = 0;
        return;
    }
    if (pos <= ((tree[x].l + tree[x].r) >> 1)){
        erase(x * 2, pos);
    } else {
        erase(x * 2 + 1, pos);
    }
    update(x);
}

ll get_sum(int x, int k){
    if (tree[x].cnt == k) return tree[x].sum;
    int ls = x * 2;
    if (k <= tree[ls].cnt) return get_sum(ls, k);
    return get_sum(x * 2 + 1, k - tree[ls].cnt) + tree[ls].sum;
}

int main(){
    int n, m, mink = 0, cnt = 0, maxk = 0;
    cin >> n;
    m = (n - 1) * 2;
    for (register int i = 1; i <= m; i++){
        cin >> a[i];
    }
    sort(a + 1, a + m + 1);
    for (register int i = 1; i <= m; i++){
        if (a[i] > i){
            for (register int j = 1; j < n; j++){
                cout << -1 << " ";
            }
            return 0;
        }
        if (a[i] == i){
            mink++;
            mark[a[i]] = true;
        } else {
            b[++cnt] = a[i];
        }
        if (!vis[a[i]]){
            vis[a[i]] = true;
            maxk++;
        }
    }
    if (mink >= n || mink > maxk){
        for (register int i = 1; i < n; i++){
            cout << -1 << " ";
        }
        return 0;
    }
    reverse(b + 1, b + cnt + 1);
    for (register int i = cnt, j = 0; i >= 1; i--){
        if (!mark[b[i]] && b[i] != b[i + 1]) diff[++j] = i;
    }
    build(1, 1, cnt);
    for (register int i = 1; i < mink; i++){
        cout << -1 << " ";
    }
    for (register int i = mink; i <= maxk; i++){
        if (i > mink) erase(1, diff[i - mink]);
        cout << get_sum(1, i) << " ";
    }
    for (register int i = maxk + 1; i < n; i++){
        cout << -1 << " ";
    }
    return 0;
}