题解 CF490F 【Treeland Tour】

· · 题解

模拟赛做到这题的加强版(n\leq 2\cdot 10^5),然而我一直没想到另维护一个下降子序列,结果最后只写出一个 n^2 换根 \kk

写完才发现文章有点长,因此有许多细节部分可能因为排版混乱(不太会排长文章...)导致表意不清,还请见谅qaq

有疑惑可以在评论区问我

解析

这题的做法好像还蛮多的...这里就把目前我知道的做法都讲一遍把X

换根做法(O(n^2)

首先任意规定一个根做树 dp,计算出以每个结点为子序列结尾,子序列开头在其子树内的最长上升子序列长度。转移可以暴力转,也可以用各种方式优化

此时根的 dp 值就是以根为子序列结尾,能得到的最长上升子序列长度

接着考虑换根。可以每次直接暴力转移受影响的两个结点的 dp 值。这个转移我想不到什么能优化的方法(也成为了这种做法的复杂度瓶颈 \kk)

对每次换根得到的 dp 值取 \max 即可

一个 trick

(其实也不算是 trick,但我想不出更合适的小标题名字...)

在换根 dp 做法中,我们始终默认当前结点是整个上升子序列的结尾,从而导致最后非换根不可。因为这样我们无法统计当前结点为序列中间点时的答案

于是思考有没有办法维护当前结点为序列中间点时的答案。可以发现,这实际上就是把一个以当前结点为结尾的上升子序列和下降子序列 “拼” 起来(合并)

不过这样还是有些问题,我们没法统计到当前结点为中间点,子序列一部分在父亲处的答案。实际上,在每个结点处,我们只需统计两部分分别在该结点两个儿子子树中的子序列就行了

于是只需另维护一个下降子序列,在每个结点处均两两合并统计一次答案即可

合并的具体方式

上一节提到的 “把一个以当前结点为结尾的上升子序列和下降子序列 ‘拼’ 起来” 的具体实现其实还有待商讨

(首先在合并每个子树的线段树/数组前,我们都默认先将当前结点插入线段树/数组,以方便讨论。因为求的是严格上升子序列,因此这样不会对答案造成影响(即使求的不是严格上升子序列,也可以考虑在统计合并的答案后,还原新的线段树/数组,再将新线段树/数组合并))

可以发现,如果在将所有儿子子树贡献都合并后(因具体做法而异)再找一条最长上升子序列和最长下降子序列拼在一起,它们之间可能会有共用的元素。具体来说,不应当组合来自同一个儿子子树的最长上升子序列和最长下降子序列

形式化地讲,我们有 m 个集合 A_1, A_2, \cdots, A_m,最后答案 (a, b)a, b 之间无序) 必须满足 a\in A_i, b\in A_j, i\not=j

f(A, B) 表示统计所有 (a, b) 满足 a\in A, b\in B。如果暴力地枚举带入 A_i,则共需要计算 m^2f(A, B)(况且每次计算的复杂度也不低)。我们考虑一个方法:

设当前已合并的集合 S(初始为空),新加入的集合 A_i,我们计算一次 f(S, A_i),然后再将 A 并入 S,如此重复。可以证明这样和刚才的方式是等价的

考虑一个元素 a\in A_i。在 A_i 被并入 S 时,f 统计了所有可能的 (a, b) 满足 a\in A_i, b\in A_j, j<i;在 A_i 被并入 S 后,f 统计了所有可能的 (a, b) 满足 a\in A_i, b\in A_j, j>i。由此可证得每种合法的 (a, b) 都被统计了

### dsu on tree I($O(n\log^2 n)$) 不带修改的全局子树贡献统计——很自然地就能想到 dsu on tree 如果按 “换根做法” 那一节的 dp 方式,需要用线段树优化转移(权值线段树,动态开点。动态开点按本题数据范围可能不需要) 具体来说,我们开 $\log n$ 个线段树;若某个结点到根经过了 $i$ 条轻边,就使用第 $i+1$ 个线段树。dfs 到 $u$ 时,先继承重儿子的线段树,将 $u$ 插入该线段树,再暴力地将所有轻儿子子树内结点插入该线段树 由于 dfs 同一时刻只会处理一条到根经过了 $i$ 条轻边的重链,因此线段树的占用不会冲突。但是记得在**叶子结点**重置线段树 每次继承重儿子线段树并插入的复杂度是 $O(\log n)$ 的;每个结点**每次**作为轻儿子子树内结点被暴力统计的复杂度是 $O(\log n)$ 的。因此这部分的总复杂度为 $O(n\log n+n\log^2 n)

 

最后考虑如何快速合并最长上升子序列和最长下降子序列

考虑在合并轻儿子子树前先统计每个轻儿子子树内结点作为答案链的中间一点的情况

具体来说,若假设 x 作为答案链的中间一点,那么就在继承的重儿子线段树中查找出适合的上升/下降链拼上去,并与答案比较

每个轻儿子子树内结点每次被统计贡献的复杂度是 \log n 的。最终复杂度就为 O(n\log n+n\log^2 n+n\log^2 n)

dsu on tree II(O(n\log^2 n)

同样是 dsu on tree,还可以考虑用维护辅助数组的 dp 方法

g(k) 表示长度为 k 的子序列末端最小/最大的元素。可以发现 g(k) 的值随 k 的增加是单调的,于是就可以二分转移

具体来说,我们仍旧开 \log ng 数组。dfs 到 u 时,直接继承重儿子的 g 数组,二分得到 dp[u],然后按下标暴力合并轻儿子的 g 数组(注意每 dfs 完一个轻儿子就要合并,不然数组使用会冲突)

每次计算 dp[.] 的复杂度是 O(\log n) 的;g 数组最坏大小为子树大小,因此我们就粗略地将合并 g 数组视为暴力统计子树内的每个结点,且统计单个结点的复杂度为 O(1),因此每个结点对暴力统计的复杂度做出的贡献就可以视为是 O(\log n) 的(实际上并跑不满)。于是这部分的总复杂度就为 O(n\log n+n\log n)

 

最后考虑如何快速合并

在合并轻儿子子树前,我们仍旧考虑先统计轻儿子子树内结点作为答案链的中间一点的情况,但这个结点必须在该子树的 g[.] 中出现过(实现时只遍历轻儿子的 g 数组就行了)

具体来说,若假设 x 作为答案链的中间一点,那么就在继承的重儿子 g[.] 中二分查找出适合的上升/下降链拼上去,并与答案比较

每个轻儿子子树内结点每次被统计贡献的复杂度是 \log n 的。最终复杂度就为 O(n\log n+n\log n+n\log^2 n)

线段树合并(O(n\log n)

思考 “dsu on tree” 中的第一种实现方法,可以想到每次不一个个暴力插入轻儿子子树内的结点,而是直接合并轻儿子的线段树。接着还发现,这时候复杂度已经和 dsu on tree 没什么关系了,每次也不需要一定先继承重儿子的线段树,继承任意一个线段树就可以了

每次继承线段树并插入的复杂度是 O(\log n) 的;线段树合并的总复杂度是 O(n\log n) 的。因此这部分的总复杂度为 O(n\log n+n\log n)

 

考虑如何统计合并的答案

看起来我们只能暴力遍历新线段树中的每个元素(叶节点)并为它们在已合并的线段树找到合适的链 “拼” 上去,但实际上我们可以直接在线段树合并时统计答案

首先维护的最长上升和最长下降子序列的信息可以放到一个线段树中;

具体来说,设我们正在合并的两个线段树结点 x1, x2,我们只考虑统计跨 mid=\frac {l_x+r_x} 2 的答案,那么显然只需在 [l, mid] 中找出一条最长的 最长上升/下降子序列 与 [mid+1, r] 中找出的一条最长的 最长上升/下降子序列 拼在一起。这样找出的两条拼在一起一定是合法的,因为线段树的下标即为结尾元素的值。而这个操作显然可以在合并 x1, x2 的递归函数中 O(1) 完成

最后统计出来的答案也是不遗漏的

也可参考这段代码:

void merge(int &x1, int x2){
    /*...*/
    int tmp =max(data_LIS[ls[x1]]+data_LDS[rs[x2]], data_LIS[rs[x1]]+data_LDS[ls[x2]]);
    Ans =max(Ans, tmp);
    merge(ls[x1], ls[x2]), merge(rs[x1], rs[x2]);
}

这样,答案统计的复杂度就被均摊到线段树合并中了。最终复杂度就为 O(n\log n+n\log n+n\log n)

长剖(O(n\log n)

思考 “dsu on tree” 中的第二种实现方法,我们的 g[.] 的下标上限实际上是和当前链长(深度)有关的,并且计算 dp[.] 的复杂度和 dsu on tree 部分无关,因此可以考虑用长剖优化

具体实现直接套用长剖维护 g[.] 的合并就行了,如对长剖不熟悉可以参考后面的代码

每次计算 dp[.] 的复杂度是 O(\log n) 的;计算 g[.] 的总复杂度是 O(n) 的。因此这部分的总复杂度为 O(n\log n+n)

 

考虑如何统计合并的答案

在合并轻儿子子树前,我们仍旧考虑先统计轻儿子子树内结点作为答案链的中间一点的情况。和 dsu on tree II 一样,这个结点必须在该子树的 g[.] 中出现过(实现时只遍历轻儿子的 g 数组就行了)

具体来说,若假设 x 作为答案链的中间一点,那么就在继承的重儿子 g[.] 中二分查找出适合的上升/下降链拼上去,并与答案比较

(上面两段也基本都是 copy 的X)

每条链会且仅会被合并一次。即使 g 数组都是最坏情况,即下标为每条链的链长,每条链在被合并、统计答案时,统计答案部分也只会做出 O(\log n) 乘链长(g 数组下标)的复杂度。于是总的复杂度就为 O(n\log n+n+n\log n)

 

这种做法理论上应该是跑得最快的了,并且复杂度瓶颈是在 lower_bound(二分查找)。不过我的实现貌似并不优秀,并且为了方便直接用 vector 实现了长剖,因此常数可能会有点大 \kk

CODE

这里只实现了两种

换根做法

注意每次从 dfs 返回时要还原 dp 值,否则后面暴力转移时会出错

#include <cstdio>
#include <vector>
using std::vector;
using std::max;

const int MAXN =6e3+50;

/*------------------------------Map------------------------------*/

vector<int> E[MAXN];

inline void addedge(int u, int v){
    E[u].push_back(v), E[v].push_back(u);
}

/*------------------------------Dfs------------------------------*/

int Ans =0;
int a[MAXN];
int dp[MAXN];

int dfs_find(int u, int fa, const int &val){
    int ret =(a[u] < val) ? dp[u] : 0;
    for(int v:E[u])
        if(v != fa)
            ret =max(ret, dfs_find(v, u, val));
    return ret;
}

void dfs1(int u, int fa){
    for(int v:E[u])
        if(v != fa)
            dfs1(v, u);
    dp[u] =0;
    dp[u] =dfs_find(u, fa, a[u])+1;
    Ans =max(Ans, dp[u]);
}

void dfs2(int u, int fa){
    for(int v:E[u])
        if(v != fa){
            int tmp1 =dp[u];
            dp[u] =0;
            dp[u] =dfs_find(u, v, a[u])+1;
        //  Ans =max(Ans, dp[u]);/*一定不更优*/

            int tmp2 =dp[v];
            dp[v] =0;
            dp[v] =dfs_find(v, 0, a[v])+1;
            Ans =max(Ans, dp[v]);
            dfs2(v, u);

            dp[u] =tmp1;
            dp[v] =tmp2;/*<- 需要还原*/
        }
}

/*------------------------------Main------------------------------*/

inline int read(){
    int x =0; char c =getchar(); bool f =0;
    while(c < '0' || c > '9') (c == '-') ? f =1, c =getchar() : c =getchar();
    while(c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (48^c), c =getchar();
    return (f) ? -x : x;
}

int main(){
    int n =read();
    for(int i =1; i <= n; ++i)
        a[i] =read();
    for(int i =0; i < n-1; ++i)
        addedge(read(), read());
    dfs1(1, 0);
    dfs2(1, 0);
    printf("%d", Ans);
}

长剖做法

这里为了方便,直接通过将元素取反、维护最长上升子序列的方式来维护最长下降子序列

写得可能不是很标准...仅供参考

#include <cstdio>
#include <vector>
#include <algorithm>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast", "-funroll-loops", "-fdelete-null-pointer-checks")
#pragma GCC target("ssse3", "sse3", "sse2", "sse", "avx2", "avx")
using std::vector;
using std::max;
using std::min;
using std::lower_bound;

const int MAXN =6e3+50;

/*------------------------------Map------------------------------*/

vector<int> E[MAXN];

inline void addedge(int u, int v){
    E[u].push_back(v), E[v].push_back(u);
}

/*------------------------------Dfs------------------------------*/

int Ans =1;
int a[MAXN];

inline void add(vector<int> &v, const int &ai){/*更新 LIS/LDS*/
    int len =lower_bound(v.begin(), v.end(), ai)-v.begin()-1;
    if(len+1 >= (int)v.size())
        v.push_back(ai);
    else
        v[len+1] =min(v[len+1], ai);
}

void updata_ans(const vector<int> &v1, const vector<int> &v2){
    for(int len_v1 =0; len_v1 < (int)v1.size(); ++len_v1)
        Ans =max(Ans, len_v1 + (int)(lower_bound(v2.begin(), v2.end(), -v1[len_v1])-v2.begin()-1));
}

inline void vector_free_memory(vector<int> &v){/*释放 vector 内存 ( 实际上在本题没啥必要 X )*/
    vector<int> tmp;
    v.swap(tmp);
}

void merge(vector<int> &v1, vector<int> &v2){/*合并 v2 到 v1*/
    for(int i =0; i < (int)v2.size(); ++i){
        /*注意 g[.] 的下标上限不一定就等于链长*/
        /*因此即使继承了长儿子的 g[.],也要考虑下标越界的情况*/
        if(i >= (int)v1.size())
            v1.push_back(v2[i]);
        else
            v1[i] =min(v1[i], v2[i]);
    }
    vector_free_memory(v2);
}

/*下面这些即为题解中所述的 "g[.]"*/
vector<int> tail_LIS[MAXN];/*长度为 i 的最小结尾元素;可知随 i 减小 tail[i] 单调不增*/
vector<int> tail_LDS[MAXN];/*同上,这里为了方便直接将元素都取反了*/

int dfs(int u, int fa){
    int mxlen =0, mxi =-1;
    for(int v:E[u])
        if(v != fa){
            int tmp =dfs(v, u);
            if(tmp > mxlen)
                mxlen =tmp, mxi =v;
        }
    tail_LIS[u].push_back(-0x7f7f7f7f), tail_LIS[u].push_back(a[u]);
    tail_LDS[u].push_back(-0x7f7f7f7f), tail_LDS[u].push_back(-a[u]);
    if(mxi != -1){
        int v =mxi;
        add(tail_LIS[v], a[u]), add(tail_LDS[v], -a[u]);
        updata_ans(tail_LIS[u], tail_LDS[v]);
        updata_ans(tail_LDS[u], tail_LIS[v]);
        tail_LIS[u].swap(tail_LIS[v]);
        tail_LDS[u].swap(tail_LDS[v]);
    }
    for(int v:E[u])
        if(v != fa && v != mxi){
            add(tail_LIS[v], a[u]), add(tail_LDS[v], -a[u]);
        /*1. 更新答案 - 有用到一个小 trick*/
        /*2. 虽然两个集合都可能有以 a[u] 结尾的链,但它们并不会连在一起*/
        /*2. 如果 < 被改成 <=,也可做:*/
        /*2. 具体来说,每次给 v 的答案加上 a[i],与已遍历的 u 答案合并,再还原 v 的答案并加入已遍历的 u 答案即可 V*/
            updata_ans(tail_LIS[v], tail_LDS[u]);
            updata_ans(tail_LDS[v], tail_LIS[u]);
        //  modify(tail_LIS[v]), modify(tail_LDS[v]);/*<- 就像这样 2.*/ 
            merge(tail_LIS[u], tail_LIS[v]);
            merge(tail_LDS[u], tail_LDS[v]);
        }
    return mxlen+1;
}

/*------------------------------Main------------------------------*/

inline int read(){
    int x =0; char c =getchar(); bool f =0;
    while(c < '0' || c > '9') (c == '-') ? f =1, c =getchar() : c =getchar();
    while(c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (48^c), c =getchar();
    return (f) ? -x : x;
}

int main(){
    int n =read();
    for(int i =1; i <= n; ++i)
        a[i] =read();
    for(int i =0; i < n-1; ++i)
        addedge(read(), read());
    dfs(1, 0);
    printf("%d", Ans);
}