P10241 [THUSC 2021] 白兰地厅的西瓜 - Solution
双倍经验
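简单来说，就是求树上任意一条路径（不钦定起点和终点）上的最长严格上升子序列（LIS）。令 $f_{u,x}$ 表示：在 $u$ 的子树中，取某个权值为 $x$ 的点作为最上端，只向它的后代方向延伸、权值严格递增的子序列的最大长度；$g_{u,x}$ 同理，但要求向下方向权值严格递减。
这是经典线段树合并。
显然可以枚举 LIS 所在路径两端点的 LCA $u$。
那么 LIS 可以拆成两段：一段位于 $u$ 的某个儿子 $v_1$ 的子树中，自下而上权值严格递增（即 $g_{v_1}$ 描述的链）；另一段位于另一个儿子 $v_2$ 的子树中，自上而下权值严格递增（即 $f_{v_2}$ 描述的链）；$u$ 本身可选可不选。
枚举两段之间的分界值：不选 $u$ 时，答案为 $\max_{x}\left(\max_{y\le x} g_{v_1,y} + \max_{z>x} f_{v_2,z}\right)$；选 $u$ 时，答案为 $\max_{y<a_u} g_{v_1,y} + 1 + \max_{z>a_u} f_{v_2,z}$。
于是对每个点用一棵以权值为下标的动态开点线段树，维护各结尾权值 $x$ 对应的 $f_{u,x}$、$g_{u,x}$ 的区间最大值：合并两个儿子的线段树时，在每个节点上用“一棵树左半区间的 $g$ + 另一棵树右半区间的 $f$”统计不选 $u$ 的答案；在 $u$ 处用区间查询统计选 $u$ 的答案；最后把 $f_{u,a_u}$、$g_{u,a_u}$ 插入 $u$ 的线段树。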
另外这题评蓝是不是评的太低了点。
#include <bits/stdc++.h>
// Shorthand for pair members: p.X is p.first, p.Y is p.second.
#define X first
#define Y second
using namespace std;
using ll = long long;
typedef pair<int, int> pii;
// Capacity bound used to size the per-vertex arrays (and, scaled, the edge and
// segment-tree node arrays).
const int maxn = 200010;
// mod = 1e9 + 7 (not referenced by the code visible in this file).
constexpr int mod = 1000000007;
// S: upper end of the value domain [0, S] indexed by the segment trees.
constexpr int S = 1000000010;
// Adjacency list in "forward star" form. h[u] is the index of the most recently
// added edge leaving u (main fills h with -1 = "none" before any edge is added);
// nd[e].nxt links to the previous edge leaving the same vertex.
// cnt: index of the next nd[] slot that add() writes.
// rt[u]: root node index of u's segment tree (0 = empty tree).
struct edge {
    int to;   // head vertex of this directed edge
    int nxt;  // previous edge out of the same tail vertex, or -1
};
edge nd[maxn << 1];
int h[maxn];
int cnt = 0;
int rt[maxn];

// Push the directed edge u -> v onto the front of u's edge list.
inline void add(int u, int v) {
    const int e = cnt++;
    nd[e] = edge{v, h[u]};
    h[u] = e;
}
// One node of the value-indexed, lazily allocated segment trees.
// l / r: child indices into t[] (0 = no child). t[0] is never written, so it
// stays all-zero and serves as the empty sentinel.
// f / g: maximum f / g value stored at any position inside this node's range.
struct Node {
    int l, r, f, g;
    Node() { l = r = f = g = 0; }
} t[maxn << 5];
// Best LIS length found so far; printed as the answer.
int ans = 0;
#define ls(x) (t[x].l)
#define rs(x) (t[x].r)
#define f(x) (t[x].f)
#define g(x) (t[x].g)
#define mid ((l + r) >> 1)
// n: number of vertices; a[i]: value of vertex i;
// tot: number of segment-tree nodes allocated so far (indices 1..tot are in use).
int n, a[maxn], tot = 0;

// Insert v at position d into the tree rooted at x, which covers [l, r].
// Raises f (tp == 1) or g (any other tp) to at least v on every node of the
// root-to-leaf path, allocating missing nodes. x is updated in place.
void ins(int l, int r, int d, int v, int& x, int tp) {
    // Allocate from the dedicated node counter `tot`, not the edge counter `cnt`.
    // Node indices then start at 1, and the whole t[] pool is available to the trees.
    if (!x) x = ++tot;
    if (tp == 1) f(x) = max(f(x), v);
    else g(x) = max(g(x), v);
    if (l == r) return;
    if (d <= mid) ins(l, mid, d, v, ls(x), tp);
    else ins(mid + 1, r, d, v, rs(x), tp);
}
// Range-maximum query over positions [ql, qr] in the tree rooted at x,
// which covers [l, r]. The returned Node's f / g are the maxima over that range
// (0 if nothing is stored there). Callers should read only f and g; l and r
// carry no meaning for them.
Node qry(int l, int r, int ql, int qr, int x) {
    // Empty subtree or empty range: nothing to report.
    // The empty-range check matters when a caller passes qr = a[u] - 1 with
    // a[u] == 0. Without it, the recursion keeps descending into [0, 0] and
    // never terminates.
    if (!x || ql > qr) return Node();
    if (ql <= l && r <= qr) return t[x];
    Node res;
    if (ql <= mid) {
        const Node lo = qry(l, mid, ql, qr, ls(x));
        res.f = lo.f;
        res.g = lo.g;
    }
    if (qr > mid) {
        const Node hi = qry(mid + 1, r, ql, qr, rs(x));
        res.f = max(res.f, hi.f);
        res.g = max(res.g, hi.g);
    }
    return res;
}
// Merge the tree rooted at y into the tree rooted at x (both cover the same
// range) and return the merged root. y's nodes are absorbed, not freed.
// Before recursing, pair the best g in one tree's left half with the best f in
// the other tree's right half. Every left-half position is strictly smaller
// than every right-half position, so each such pair is a candidate for ans.
int merge(int x, int y) {
    if (!x || !y) return x | y;  // one side is empty: reuse the other unchanged
    const int cross = max(g(ls(x)) + f(rs(y)), g(ls(y)) + f(rs(x)));
    ans = max(ans, cross);
    f(x) = max(f(x), f(y));
    g(x) = max(g(x), g(y));
    ls(x) = merge(ls(x), ls(y));
    rs(x) = merge(rs(x), rs(y));
    return x;
}
// Post-order DFS from u (pa = parent vertex, 0 for the root).
// Builds rt[u] as the merge of the children's trees plus u's own f / g entries
// at position a[u], and updates ans with every candidate that passes through u.
// NOTE: recursion depth equals the tree height, so a path-shaped input recurses
// about n levels deep.
void dfs(int u, int pa) {
    // st: best f among the children at values > a[u];
    // ed: best g among the children at values < a[u].
    int st = 0, ed = 0;
    for (int i = h[u]; ~i; i = nd[i].nxt) {
        int v = nd[i].to;
        if (v == pa) continue;
        dfs(v, u);
        int L = qry(0, S, a[u] + 1, S, rt[v]).f, R = qry(0, S, 0, a[u] - 1, rt[v]).g;
        st = max(st, L), ed = max(ed, R);
        // Join this child with the children already merged into rt[u], using u itself.
        ans = max(ans, qry(0, S, a[u] + 1, S, rt[u]).f + R + 1);
        ans = max(ans, qry(0, S, 0, a[u] - 1, rt[u]).g + L + 1);
        rt[u] = merge(rt[u], rt[v]);
    }
    // A chain whose topmost vertex is u and which extends into at most one child
    // (length 1 for a leaf). For a vertex with children, the loop above has
    // already reached these values. Without this line, a single-vertex tree
    // (n == 1) would print 0 instead of 1.
    ans = max(ans, max(st, ed) + 1);
    ins(0, S, a[u], st + 1, rt[u], 1);
    ins(0, S, a[u], ed + 1, rt[u], 2);
}
// Entry point: read n, the vertex values a[1..n] and n - 1 undirected edges,
// run the DFS from vertex 1, then print the best length found.
int main() {
    memset(h, -1, sizeof h);  // -1 terminates every (still empty) edge list
    scanf("%d", &n);
    for (int v = 1; v <= n; ++v)
        scanf("%d", a + v);
    for (int rest = n - 1; rest > 0; --rest) {
        int p, q;
        scanf("%d%d", &p, &q);
        // Undirected edge: store both directions.
        add(p, q);
        add(q, p);
    }
    dfs(1, 0);
    printf("%d\n", ans);
    return 0;
}