P10241 [THUSC 2021] 白兰地厅的西瓜 - Solution

· · 题解

双倍经验

简单来说就是求树上不钦定起点结尾的 LIS,令 sub_xx 子树的集合。

这是经典线段树合并。

显然枚举 LIS 端点 s,\,t 的 LCA,比如是 u

那么 LIS 可以拆成 u 的某子树内,向上到 u,然后向下到 u 的另外一棵子树。

枚举 t 所在的子树 v,每次枚举完 v 之后令 S \leftarrow S \cup sub_v,有两种情况:

于是用动态开点线段树维护结尾权值是 x 的向根最长上升链,维护开头权值是 y 的向叶最长上升链。线段树合并即可,时间复杂度 \Theta(n \log n)

另外这题评蓝是不是评的太低了点。

#include <bits/stdc++.h>
#define X first
#define Y second
using namespace std;
typedef long long int ll;
using pii = pair<int, int>;
const int maxn = 2e5 + 10;
constexpr int mod = 1e9 + 7, S = 1e9 + 10;
struct edge { int to, nxt; } nd[maxn << 1]; int h[maxn], cnt = 0, rt[maxn];
inline void add(int u, int v) { nd[cnt].nxt = h[u], nd[cnt].to = v, h[u] = cnt++; }
struct Node {
    int l, r, f, g;
    Node() { l = r = f = g = 0; }
} t[maxn << 5]; int ans = 0;
#define ls(x) (t[x].l)
#define rs(x) (t[x].r)
#define f(x) (t[x].f)
#define g(x) (t[x].g)
#define mid (l + r >> 1)
int n, a[maxn], tot = 0;
void ins(int l, int r, int d, int v, int& x, int tp) {
    if (!x) x = ++cnt; 
    tp == 1 ? f(x) = max(f(x), v) : g(x) = max(g(x), v); if (l == r) return;
    d <= mid ? ins(l, mid, d, v, ls(x), tp) : ins(mid + 1, r, d, v, rs(x), tp);
}
Node qry(int l, int r, int ql, int qr, int x) {
    if (ql <= l && r <= qr) return t[x]; Node u, L, R;
    if (ql <= mid) L = qry(l, mid, ql, qr, ls(x)), u.f = L.f, u.g = L.g;
    if (qr > mid) R = qry(mid + 1, r, ql, qr, rs(x)), u.f = max(u.f, R.f), u.g = max(u.g, R.g);
    return u;
}
int merge(int x, int y) {
    if (!x || !y) return x | y;
    else {
        ans = max({ ans, g(ls(x)) + f(rs(y)), g(ls(y)) + f(rs(x)) });
        f(x) = max(f(x), f(y)), g(x) = max(g(x), g(y));
        ls(x) = merge(ls(x), ls(y)), rs(x) = merge(rs(x), rs(y));
        return x;
    }
}
void dfs(int u, int pa) {
    int st = 0, ed = 0;
    for (int i = h[u]; ~i; i = nd[i].nxt) {
        int v = nd[i].to;
        if (v == pa) continue; dfs(v, u);
        int L = qry(0, S, a[u] + 1, S, rt[v]).f, R = qry(0, S, 0, a[u] - 1, rt[v]).g;
        st = max(st, L), ed = max(ed, R);
        ans = max(ans, qry(0, S, a[u] + 1, S, rt[u]).f + R + 1);
        ans = max(ans, qry(0, S, 0, a[u] - 1, rt[u]).g + L + 1);
        rt[u] = merge(rt[u], rt[v]);
    }
    ins(0, S, a[u], st + 1, rt[u], 1);
    ins(0, S, a[u], ed + 1, rt[u], 2);
}
int main() {
    memset(h, -1, sizeof(h));
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v); add(v, u);
    }
    dfs(1, 0);
    printf("%d\n", ans);
    return 0;
}