题解 CF490F Treeland Tour

· · 题解

很久以前做省选模拟赛的 T1 ,那时候数据范围是 2 \times 10 ^ 5 于是果断敲了个平衡树套启发式合并……现在学了线段树合并后感觉还是挺简单的。

我们考虑在枚举一个点时维护两种树上最长上升子序列:

对于第一种很简单,我们考虑维护以一个点结尾的最长上升子序列以及以一个点开头的最长下降子序列,那么我们合并时就直接选子树中两个最长的合法最长上升子序列和最长下降子序列即可,这些都可以线段树合并维护。

对于第二种情况,我们直接在线段树合并的时候从合并的两棵线段树里每层的左右子树里面选左边的最长上升子序列,右边的最长下降子序列合并即可,记得交换一下两棵树的顺序再更新一次答案。

#include "bits/stdc++.h"
using namespace std;
const int Len = 1e5 + 5;
int n,m,lsh[Len],head[Len],cnt,val[Len],ans,Cnt,maxnans[Len][2],Ans;
struct T_T
{
    int lc[Len * 80] , rc[Len * 80] , maxn[Len * 80][2],ans[Len * 80],rt[Len],tot;
    void push_up(int x)
    {   
        maxn[x][0] = max(maxn[lc[x]][0] , maxn[rc[x]][0]);
        maxn[x][1] = max(maxn[lc[x]][1] , maxn[rc[x]][1]);
    }
    int clone(){tot ++; lc[tot] = rc[tot] = maxn[tot][0] = maxn[tot][1] = 0 ; return tot;}
    int merge(int a,int b,int l,int r)
    {
        Ans = max(Ans , max(maxn[lc[a]][1] + maxn[rc[b]][0] , maxn[lc[b]][1] + maxn[rc[a]][0]));
        if(!a || !b) return a + b;
        if(l == r)
        {
            maxn[a][0] = max(maxn[a][0] , maxn[b][0]) ; maxn[a][1] = max(maxn[a][1] , maxn[b][1]);
            return a;
        }
        int mid = (l + r) >> 1;
        lc[a] = merge(lc[a] , lc[b] , l , mid) , rc[a] = merge(rc[a] , rc[b] , mid + 1 , r);
        push_up(a);
        return a;   
    } 
    void update(int p,int l,int r,int idx,int opt,int w)
    {   
        if(l == r) 
        {
            maxn[p][opt] = max(maxn[p][opt] , w);
            return;
        }
        int mid = (l + r) >> 1;
        if(idx <= mid) 
        {
            if(!lc[p]) lc[p] = clone();
            update(lc[p] , l , mid , idx , opt , w);
        }
        else 
        {
            if(!rc[p]) rc[p] = clone();
            update(rc[p] , mid + 1 , r , idx , opt , w);
        }
        push_up(p);
    }
    int query(int p,int l,int r,int nl,int nr,int opt)
    {
        if(nl <= l && nr >= r) return maxn[p][opt];
        int mid = (l + r) >> 1 , res = 0;
        if(nl <= mid) 
        {
            if(!lc[p]) lc[p] = clone();
            res = query(lc[p] , l , mid , nl , nr , opt);
        }
        if(nr > mid) 
        {
            if(!rc[p]) rc[p] = clone();
            res = max(res , query(rc[p] , mid + 1 , r , nl , nr , opt));
        }
        return res;
    }
}S1;
struct node
{
    int next,to;
}edge[Len << 1];
struct Node
{
    int x,y;
}fos[Len];
bool cmpx(Node x,Node y){return x.x < y.x;}
bool cmpy(Node x,Node y){return x.y < y.y;}
void add(int from,int to){edge[++ cnt].to = to ; edge[cnt].next = head[from] ; head[from] = cnt;}
void dfs(int x,int f)
{
    S1.rt[x] = S1.clone();
    int tots = 0;int maxn[2] = {0 , 0};
    for(int e = head[x] ; e ; e = edge[e].next)
    {
        int to = edge[e].to;
        if(to == f) continue;
        dfs(to , x);
    }
    for(int e = head[x] ; e ; e = edge[e].next)
    {
        int to = edge[e].to;
        if(to == f) continue;
        tots ++;
        Ans = 0;
        if(val[x] == Cnt) fos[tots].x = 1;
        else fos[tots].x = S1.query(S1.rt[to] , 1 , Cnt , val[x] + 1 , Cnt , 0) + 1;
        if(val[x] == 1) fos[tots].y = 1;
        else fos[tots].y = S1.query(S1.rt[to] , 1 , Cnt , 1 , val[x] - 1 , 1) + 1;
        maxn[0] = max(maxn[0] , fos[tots].x);
        maxn[1] = max(maxn[1] , fos[tots].y);
        S1.rt[x] = S1.merge(S1.rt[x] , S1.rt[to] , 1 , Cnt);
        ans = max(ans , Ans);
    }
    if(!tots){ans = max(ans , 1);maxn[0] = maxn[1] = 1;}
    else if(tots == 1)
    {
        ans = max(ans , fos[tots].x);
        ans = max(ans , fos[tots].y);
    }
    else
    {
        sort(fos + 1 , fos + 1 + tots , cmpx);
        sort(fos + 1 , fos + tots , cmpy);
        ans = max(ans , fos[tots].x + fos[tots - 1].y - 1);
        sort(fos + 1 , fos + 1 + tots , cmpy);
        sort(fos + 1 , fos + tots , cmpx);
        ans = max(ans , fos[tots].y + fos[tots - 1].x - 1);
    }
    S1.update(S1.rt[x] , 1 , Cnt , val[x] , 0 , maxn[0]);
    S1.update(S1.rt[x] , 1 , Cnt , val[x] , 1 , maxn[1]);
}
int main()
{
    scanf("%d",&n);
    for(int i = 1 ; i <= n ; i ++) 
    {
        scanf("%d",&val[i]);
        lsh[++ Cnt] = val[i];
    }
    sort(lsh + 1 , lsh + 1 + Cnt);
    Cnt = unique(lsh + 1 , lsh + 1 + Cnt) - lsh - 1;
    for(int i = 1 ; i <= n ; i ++) val[i] = lower_bound(lsh + 1 , lsh + 1 + Cnt , val[i]) - lsh;
    for(int i = 1 ; i < n ; i ++)
    {
        int x,y;scanf("%d %d",&x,&y);
        add(x , y) , add(y , x);
    }
    dfs(1 , 0);
    printf("%d\n",ans);
    return 0;
}