题解 CF1208H 【Red Blue Tree】

· · 题解

题意

给出一棵树和一个参数 k,叶子有蓝色和红色,对于一个非叶子节点 u,如果 \text{蓝色儿子}-\text{红色儿子}\ge k 那么这个节点为蓝色否则为红色。需要维护三个操作:

  1. 询问一个节点颜色
  2. 更改叶子颜色
  3. 更改 k

题解

虽然是 *3500 但是不是特别绕,只是代码有一点难写,蒟蒻写了 114514 分钟才写完。

考虑如果没有叶子颜色的变化,那么随着 k 的增加,感性理解一下发现一个节点变红更加容易。因此,对于一个节点 u,我们可以处理出刚刚好变成红色的时的 k_u,然后只需要用当前的 kk_u 比较大小就可以得到颜色了。

k_u 有一个比较 \rm Naive 的做法是二分,然后看 |son_u|-2\sum_{v\in son_u}[k_v\le k_u]<k_u 也就是 k_u+2\sum_{v\in son_u}[k_v\le k_u]>|son_u| 能否满足,但是求一次是 \mathcal O(n\log n) 的,并且更改之后不容易方便地维护。考虑开一棵权值线段树,然后再权值线段树上二分,就可以 \mathcal O(\log n) 的时间更改一个 k_v\mathcal O(\log n) 计算现在的 k_u

然后再加上更改操作,看上去感觉很像 \rm DDP,那么我们先树链剖分再说。于是一个点有了轻儿子和重儿子,考虑如何将信息合并。首先我们不考虑重儿子得到一个 k,然后考虑重儿子的权值是 x 想要加入,此时 \rm{LHS}-\rm{RHS}=\Delta>0\sum_{v\in lightson_u}[k_v=k]=cnt_0,情况非常有限:

于是这个函数只有三段区间,我们把这个转移写成 k'_u=f_u(heavyson_u),这个函数可以比较方便地合并,于是写一棵平衡树或线段树就可以较快地转移了。函数 f_u(x) 就直接在权值线段树上模拟上面的过程就好了。

发现更改轻儿子的复杂度是 \mathcal O(\log n) 的,每一次最多更改 \mathcal O(\log n) 个轻儿子,一条重链更改平衡树的复杂度是 \mathcal O(\log n) 的,最多更改 \mathcal O(\log n) 条,所以复杂度是 \mathcal O(q\log^2n) 的。

代码

#include<bits/stdc++.h>
#define mp make_pair
#define mt make_tuple
#define pb push_back
#define pc putchar
#define chkmx(a,b) (a)=max((a),(b))
#define chkmn(a,b) (a)=min((a),(b))
#define fi first
#define se second
using namespace std;
template<class T>
void read(T&x){x=0;char c=getchar();bool f=0;for(;!isdigit(c);c=getchar())f^=c=='-';for(;isdigit(c);c=getchar())x=x*10+c-'0';if(f)x=-x;}
template<class T,class ...ARK>void read(T&x,ARK&...ark){read(x);read(ark...);}
template<class T>void write(T x){if(x<0)pc('-'),x=-x;if(x>=10)write(x/10);pc(x%10+'0');}
template<class T,class ...ARK>void write(T x,ARK...ark){write(x);pc(' ');write(ark...);}
template<class ...ARK>void writeln(ARK...ark){write(ark...);pc('\n');}
typedef long long ll;
#define lowbit(x) ((x)&-(x))
#define mid ((l+r)>>1)
const int N=1e5+100;
int n,q,nowk,col[N];//0=red 1=blue
vector<int>e[N];
struct dtkdSGT{
    /*
    动态开点线段树
    支持: 
    (1)单点加减
    (2)区间查询
    (2)查询最小的k满足 k + 2 sum [k_i<=k] > m   
    */
    int cnt,sum[N<<6],lc[N<<6],rc[N<<6];
    void pushup(int x){sum[x]=sum[lc[x]]+sum[rc[x]];}
    void upd(int&x,int l,int r,int pos,int val){
        if(!x)x=++cnt;
        if(l==r)return sum[x]+=val,void();
        if(pos<=mid)upd(lc[x],l,mid,pos,val);
        else upd(rc[x],mid+1,r,pos,val);
        pushup(x);
    }
    int qry(int x,int l,int r,int pre,int m){
        if(l==r)return l;
        if(mid+2*(pre+sum[lc[x]])>m)return qry(lc[x],l,mid,pre,m);
        else return qry(rc[x],mid+1,r,pre+sum[lc[x]],m);
    }
    int ct(int x,int l,int r,int ql,int qr){
        if(!x)return 0;
        if(ql<=l&&r<=qr)return sum[x];
        if(r<ql||qr<l)return 0;
        return ct(lc[x],l,mid,ql,qr)+ct(rc[x],mid+1,r,ql,qr);
    }
}T;
int fa[N],son[N],sz[N],top[N],dep[N],lcnt[N],lrt[N];
void dfs1(int u){
    dep[u]=dep[fa[u]]+1;sz[u]=1;
    for(auto v:e[u])if(v!=fa[u])
        fa[v]=u,dfs1(v),sz[u]+=sz[v],lcnt[u]++;
    if(lcnt[u])lcnt[u]--;
}
void dfs2(int u){
    if(!top[u])top[u]=u;
    if(~col[u])return;
    pair<int,int>mx=mp(0,0);
    for(auto v:e[u])if(v!=fa[u])chkmx(mx,mp(sz[v],v));
    top[son[u]=mx.se]=top[u];
    for(auto v:e[u])if(v!=fa[u])dfs2(v);
}

int root[N],pa[N],ch[N][2];//二叉树
bool isroot(int x){return ch[pa[x]][0]!=x&&ch[pa[x]][1]!=x;}
struct node{
    //维护的东西
    int k,a,b,c;
    // x>k a   x=k b   x<k-1 c
    node(int x=0){k=a=b=c=x;}
    int operator()(int x)const{
        if(x>k)return a;
        if(x==k)return b;
        if(x<k)return c;
    }
    friend node operator+(const node&f,const node&g){
        //返回 f(g(x))
        node res=g;
        res.a=f(g.a);res.b=f(g.b);res.c=f(g.c);
        return res;
    }
}f[N],mul[N];
//f是这个点,mul是这段区间复合
void pushup(int x){
    mul[x]=f[x];
    if(ch[x][1])mul[x]=mul[x]+mul[ch[x][1]];
    if(ch[x][0])mul[x]=mul[ch[x][0]]+mul[x];
}
void upd(int x){
    //计算 f
    int k=T.qry(lrt[x],-n-1,n+1,0,lcnt[x]);
    int delta=k+2*T.ct(lrt[x],-n-1,n+1,-n-1,k)-lcnt[x];
    assert(delta>0);
    int cnt0=T.ct(lrt[x],-n-1,n+1,k,k);
    f[x].k=k;
    if(delta==1)f[x].a=k+1;else f[x].a=k;
    f[x].b=k;
    if(delta-1-2*cnt0+2>1)f[x].c=k-1;
    else f[x].c=k;
}
int build(int l,int r,vector<int>&arr){
    if(l>r)return 0;
    int x=arr[mid];
    ch[x][0]=build(l,mid-1,arr);if(ch[x][0])pa[ch[x][0]]=x;
    ch[x][1]=build(mid+1,r,arr);if(ch[x][1])pa[ch[x][1]]=x;
    pushup(x);
    return x;
}
void dfs3(int u){
    //处理重链
    vector<int>now;
    for(int x=u;x;x=son[x])now.pb(x);
    for(auto x:now)
        for(auto y:e[x])if(y!=fa[x]&&y!=son[x])
            dfs3(y),T.upd(lrt[x],-n-1,n+1,mul[root[y]](0),1),pa[root[y]]=x;
    for(auto x:now){
        if(son[x])upd(x);
        else{
            assert(sz[x]==1);
            if(col[x]==0)f[x]=node(-n-1);
            else f[x]=node(n+1);
        }
    }
    root[u]=build(0,now.size()-1,now);
}
void mdf(int x){
    //现在改了x的颜色
    if(col[x]==0)f[x]=node(-n-1);
    else f[x]=node(n+1);
    for(;x;x=pa[x])
        if(isroot(x)){
            T.upd(lrt[pa[x]],-n-1,n+1,mul[x](0),-1);
            pushup(x);
            T.upd(lrt[pa[x]],-n-1,n+1,mul[x](0),1);
            upd(pa[x]);
        }else pushup(x);
}
node qry(int x){
    //询问这一条重链
    node res=f[x];
    if(ch[x][1])res=res+mul[ch[x][1]];
    while(!isroot(x)){
        if(x==ch[pa[x]][0]){
            x=pa[x];res=res+f[x];
            if(ch[x][1])res=res+mul[ch[x][1]];
        }
        else x=pa[x];
    }
    return res;
}

signed main(){
    read(n,nowk);
    for(int i=1,u,v;i<n;i++)read(u,v),e[u].pb(v),e[v].pb(u);
    for(int i=1;i<=n;i++)read(col[i]);
    dfs1(1);
    dfs2(1);
    dfs3(1);
    read(q);while(q--){
        int op;read(op);
        if(op==1){
            int v;read(v);node res=qry(v);
            assert(res.a==res.b&&res.b==res.c);
            write((int)(nowk<res(0)));pc('\n');
        }else if(op==2){
            int v;read(v);read(col[v]);
            mdf(v);
        }else read(nowk);
    }
}