题解 P3703 【[SDOI2017]树点涂色】

· · 题解

一道非常妙的 LCTQwQ,感觉做了之后会对 access 的理解深一些 QwQ

(树剖大佬别打我 QAQ )

首先 操作1 具有一个特性,染色操作全部都是直接染到顶点。

让人想到 LCT 的 access 操作。

发现每次都是开一个新的颜色,我们可以知道:一个点到根节点有多少种颜色,体现在 LCT 上即:从此点走到根节点需要走的虚链个数 + 1

初始时每个点都是一个新颜色,即 LCT 中的所有边都是虚边。

考虑记dis[x]表示从根节点走到此点需要经过多少条虚链+1,那么最初的 dis[x]dep[x]

相对而言较为好处理的是 2 操作,考虑树上差分,倍增求 LCA 不难发现答案为:dis[x] + dis[y] - 2 * dis[lca] + 1

接下来考虑操作3 操作3 需要维护的是子树信息。

所以需要用一个数据结构来维护 dis数组 -> 线段树。

如果对原树按照 dfs 进行遍历,在 dfs 序中每棵子树对应的区间都是连续的。

那么操作3就变成了在线段树中询问区间最大 dis 值。

现在我们来考虑最恶心的 操作1

这是一个 access

我们看看一般的access是怎么写的:

void access( int x ) {
    for( int y = 0; x; y = x, x = t[y].fa )
        Splay(x), t[x].son[1] = y, pushup(x);
}

如果将这段代码分得更细一点,其实是:

$2.$将 $x$ 的右儿子变虚 $3.$将 $x$ 的虚儿子$(y)$ 变实。 然后$pushup

我们发现这一过程中其实发生了:1.一条虚链变成了实链,2.一条实链变成了虚链。

考虑它原来的儿子,这条实链变成了虚链,那么它的儿子及其下面所有的点往根节点走的时候都要多经过一条虚链,换而言之:此儿子所管辖的子树区间的 dis 值全部 +1

然后另一个操作为:一个虚儿子变成了实儿子,那么此儿子变成了实儿子,那么其的儿子走到根节点的路上都会少走一条虚链,等价于:此儿子所管辖的所有子树区间的 dis 值全部 -1

注意坑点:比如这样写就是错的(伪代码):

void access( int x ) {

    for( int y = 0; x; y = x, x = t[y].fa ) {
        Splay(x), update( rs(x), + 1 ); //更新右儿子。
        update( y, -1 ), t[x].son[1] = y;
    }
}

需要注意:1.LCT 中的Splay作为辅助树,其维护的是一个中序遍历,并不是真实树中的一种父子关系,然而我们做的操作是对其真实的儿子所管辖的区间做的+1/-1 所以我们要找到其原本的儿子。

所以我们对于原树,我们需要找到其右子树(深度比其大的点)中深度最小的点。

对于接上去的点,我们也需要找到其中深度最小的点。

故其实还要一个 findroot 函数(某一颗子树中深度最小的点)

具体实现看代码中access

代码:

#include<bits/stdc++.h>
using namespace std;
int read() {
    char cc = getchar(); int cn = 0, flus = 1;
    while(cc < '0' || cc > '9') {  if( cc == '-' ) flus = -flus;  cc = getchar();  }
    while(cc >= '0' && cc <= '9')  cn = cn * 10 + cc - '0', cc = getchar();
    return cn * flus;
}
#define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
#define drep( i, t, s ) for( register int i = t; i >= s; -- i )
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define Max( x, y ) ( ( (x) > (y) ) ? (x) : (y) )
const int N = 1e5 + 5;
int L[N], R[N], rev[N], id[N], head[N], n, m, cnt, tot ;
int fath[N][21], dep[N];
struct E {
    int to, next ;
} e[N * 2];
namespace Tree {
    #define ls(x) ( x * 2 )
    #define rs(x) ( x * 2 + 1 )
    struct Tr {
        int mark, mx;
    }tr[N * 4];
    void pushup( int x ) {
        if( tr[x].mark ) {
            tr[ls(x)].mark += tr[x].mark, tr[ls(x)].mx += tr[x].mark ;
            tr[rs(x)].mark += tr[x].mark, tr[rs(x)].mx += tr[x].mark ;
            tr[x].mark = 0;
        }
    }
    void build( int x, int l, int r ) {
        if( l == r ) { tr[x].mx = dep[rev[l]]; return ; }
        int mid = ( l + r ) >> 1;
        build( ls(x), l, mid ), build( rs(x), mid + 1, r );
        tr[x].mx = Max( tr[ls(x)].mx, tr[rs(x)].mx );
    }
    int query( int x, int l, int r, int ll, int rr ) {
        if( l > rr || r < ll ) return 0;
        if( l >= ll && r <= rr ) return tr[x].mx;
        pushup(x);
        int mid = ( l + r ) >> 1;
        return max( query( ls(x), l, mid, ll, rr ), query( rs(x), mid + 1, r, ll, rr ) );
    }
    void update( int x, int l, int r, int ll, int rr, int add ) {
        if( l > rr || r < ll ) return ;
        if( l >= ll && r <= rr ) {
            tr[x].mark += add, tr[x].mx = tr[x].mx + add;
            return ;
        }
        int mid = ( l + r ) >> 1;
        pushup(x);
        update( ls(x), l, mid, ll, rr, add ), update( rs(x), mid + 1, r, ll, rr, add );
        tr[x].mx = max( tr[ls(x)].mx, tr[rs(x)].mx );
    }
}
namespace LCT1 {
    #define ls(x) t[x].son[0]
    #define rs(x) t[x].son[1]
    struct LCT {
        int son[2], fa, val;
        bool mark ;
    } t[N * 2];
    int isroot( int x ) {
        return ( ls(t[x].fa) != x ) && ( rs(t[x].fa) != x ) ;
    }
    void pushmark( int x ) {
        if( t[x].mark ) {
            swap( ls(x), rs(x) ); 
            t[x].mark = 0, t[ls(x)].mark ^= 1, t[rs(x)].mark ^= 1;
        }
    }
    void rotate( int x ) {
        int f = t[x].fa, ff = t[f].fa, qwq = ( rs(t[x].fa) == x );
        t[x].fa = ff;
        if( !isroot(f) ) t[ff].son[rs(ff) == f] = x ;
        t[f].son[qwq] = t[x].son[qwq ^ 1], t[t[x].son[qwq ^ 1]].fa = f ;
        t[x].son[qwq ^ 1] = f, t[f].fa = x ;
    }
    int st[N];
    void Splay( int x ) {
        int top = 0, now = x; st[++ top] = now ;
        while( !isroot(now) ) st[++ top] = ( now = t[now].fa );
        while( top ) pushmark( st[top --] );
        while( !isroot(x) ) {
            int f = t[x].fa, ff = t[f].fa;
            if( !isroot(f) ) ( ( rs(f) == x ) ^ ( rs(ff) == f ) ) ? rotate(x) : rotate(f) ;
            rotate(x);
        }
    }
    int findrt( int x ) { //要找到深度最小的点。 
        while( ls(x) ) x = ls(x) ;
        return x;
    }
    void access( int x ) { //access 
        int son;
        for( int y = 0; x; y = x, x = t[y].fa ) {
            Splay(x);
            if( rs(x) ) son = findrt(rs(x)), Tree :: update( 1, 1, n, L[son], R[son], 1 ); //如果儿子存在才update 
            if( rs(x) = y ) son = findrt(y), Tree :: update( 1, 1, n, L[son], R[son], -1 );
            t[x].son[1] = y ; 
        }
    }
}
void add( int x, int y ) {
    e[++ cnt] = (E){ y, head[x] }, head[x] = cnt ;
    e[++ cnt] = (E){ x, head[y] }, head[y] = cnt ; 
}
void dfs( int x, int fa ) {
    L[x] = ++ tot, rev[tot] = x; // L[x],R[x]表示x管辖的区间,rev表示线段树中的位置对应点 
    LCT1::t[x].fa = fath[x][0] = fa, dep[x] = dep[fa] + 1 ;
    rep( i, 1, 19 ) fath[x][i] = fath[fath[x][i - 1]][i - 1]; //倍增处理 
    Next( i, x ) {
        int v = e[i].to ;
        if( v == fa ) continue ;
        dfs( v, x );
    }
    R[x] = tot ;
}
int LCA( int x, int y ) { //LCA
    if( dep[x] < dep[y] ) swap( x, y );
    drep( i, 19, 0 ) if( dep[fath[x][i]] >= dep[y] ) x = fath[x][i] ;
    drep( i, 19, 0 ) if( fath[x][i] != fath[y][i] ) x = fath[x][i], y = fath[y][i] ;
    return ( x == y ) ? x : fath[x][0] ;
}
signed main()
{
    n = read(), m = read() ;
    int x, y, opt, ans ;
    rep( i, 1, n - 1 ) x = read(), y = read(), add( x, y );
    dfs( 1, 1 ); LCT1 :: t[1].fa = 0; 
    Tree :: build( 1, 1, n ); //初始建树 

    rep( i, 1, m ) {
        opt = read(), x = read();
        if( opt == 1 ) LCT1 :: access(x);
        if( opt == 2 ) {
            y = read();
            int lca = LCA( x, y );
            int ans1 = Tree :: query( 1, 1, n, L[x], L[x] ) ;
            int ans2 = Tree :: query( 1, 1, n, L[y], L[y] ) ;
            int ans3 = ( Tree :: query( 1, 1, n, L[lca], L[lca] ) );
            printf("%d\n", ans1 + ans2 - 2 * ans3 + 1 );
        }
        if( opt == 3 ) {
            ans = Tree :: query( 1, 1, n, L[x], R[x] );
            printf("%d\n", ans );
        }   
    }
    return 0;
}