题解 P3703 【[SDOI2017]树点涂色】
shadowice1984
2018-03-31 14:09:02
其实严格来讲这道题并不是lct?
(但是思路上还是lct的思路就对了,反正最后实现的时候一个splay就够了)
熟练压行的话百行以内可以写完?
(食用本题解时请确保您理解lct里的轻重边概念,但是并不需要熟练掌握lct)
~~(什么你不知道什么是轻重边?出门左转你站膜板区,包教包会)~~
# 本题题解
显然这是一道数据结构题,所以让我们先来分析一下资瓷什么操作
1.树上染色
2.树上询问颜色数量
3.子树询问内颜色最大值(显然子树内的所有点在子树过了子树根之后路径完全一样,所以可以只考虑到根的颜色个数)
那么我们会发现一个非常重要的性质,因为是到根节点路径染色,所以相同颜色的点必定在同一个联通块内,(也就是说不会有两段分开的颜色)
那么我们可以转化一下问题,路径上的颜色数目=路径上的分割边数+1
(这里指的分割边是指这个边连接的两个点颜色不同)
那么我们会发现其实如果把1操作看成lct的access操作,那么我们的“分割边”其实就是lct中的轻边
所以我们现在要资瓷的操作就是询问轻边数量?
那么看一下转换后的操作
1.access
2.询问路径上轻边数量
3.询问到根节点轻边数量最大
那么我们发现现在维护的东西是边而不是点,我们知道,如果是点的话在询问
u到v的路径信息时采用$ u+v-2lca(u,v)$的方式是会漏算lca这个点的,但是如果是边的话就可以直接使用$u+v-2lca(u,v)$这个式子了,因为各个路径的边是不相
交的
所以说我们想知道x到y的路径上有多少条轻边,只需要知道x到1有多少条轻边,y到1有多少条轻边,和x,y的lca有多少条轻边然后$u+v-2lca(u,v)$就行了
(下文中的”轻边数量“代指某个点到1路径上的轻边数量)
一开始这个树全部是轻边,所以每个点的轻边数量就等于这个节点的深度
我们发现当一个轻边变成重边的时候,以轻边底部的点为根的子树轻边数量-1
当一个重边变成轻边的时候,以重边底部的点为根的子树轻边数量+1
唯一的问题变成了如何在树上找到所有的轻边(咦这话我好像在哪里见过)
然后我们是可以lct暴力维护轻边的,根据lct的复杂度证明,我们最多修改nlogn条轻边,所以每次碰到一次轻重切换就在线段树上暴力修改即可,(子树问题可以通过dfs拍平转换为序列问题)
(或者链剖之后线段树/树状数组上二分没准也可以?)
由于我们只会进行nlogn次修改,所以总复杂度是$O(nlog^2n)$的
对于操作3是一个简单的区间最大值问题(因为我们维护的就是到根节点的轻边数量)
所以我们需要写一个资瓷区间加和查区间最大值的线段树(单点查值就是区间最大值233)和一个倍增找lca
然后下面简单介绍一下如何快速找到所有的轻边
显然我们在执行access操作的时候只会把链上的边变成重边
那么我们可以把在同一个重链里的点看成一个集合,在集合里按深度排序,我们只需要不停的集合的最小值,然后这个最小值和最小值的father就会构一个轻边了,当然我们还需要进行一些断开和合并集合的操作,比如轻边重合并两个集合,重变轻将一个集合分裂为两个集合
因此这需要我们实现一个**可分裂合并堆**,emm……也就是splay,(或者可以尝试下split-merge-treap?)然后就可以一路暴力跳轻边查找了
代码的话脑子清楚还是很好写的,但是如果逻辑混乱的话就非常不好写了
具体可以看注释加深理解
上代码~
```C
#include<cstdio>
#include<algorithm>
using namespace std;const int N=1e5+10;
int n;int m;int dfn[N];int nfd[N];int siz[N];int df;int f[N][22];
int v[2*N];int x[2*N];int al[N];int ct;int dep[N];int h[N];
struct splaytree//压行后的splay
{
int s[N][2];int fa[N];
inline int gc(int x){return s[fa[x]][1]==x;}//检查左右儿子
inline void rt(int x)//二合一旋转
{
int d=fa[x];int t=gc(x);s[d][t]=s[x][t^1];fa[s[x][t^1]]=d;
s[fa[d]][gc(d)]=x;fa[x]=fa[d];s[x][t^1]=d;fa[d]=x;
}
inline void rtup(int x){rt((gc(x)^gc(fa[x]))?x:fa[x]);rt(x);}//双旋
inline void splay(int x){for(;fa[fa[x]]&&fa[x];rtup(x));if(fa[x])rt(x);}//splay
inline int mi(int x){for(;s[x][0];x=s[x][0]);return x;}//最小值
inline int getmi(int x){splay(x);return mi(x);}//查找最小值
inline void split(int x){splay(x);fa[s[x][1]]=0;s[x][1]=0;}//分裂
inline void merge(int x,int y){splay(y);s[x][0]=y;fa[y]=x;}//合并
}spt;
struct linetree//线段树
{
int add[4*N];int ma[4*N];
inline void pushdown(int p)//pushdown
{add[p<<1]+=add[p];add[p<<1|1]+=add[p];ma[p<<1]+=add[p];ma[p<<1|1]+=add[p];add[p]=0;}
inline void setadd(int p,int l,int r,int dl,int dr,int plus)//区间加
{
if(dl==l&&dr==r){add[p]+=plus;ma[p]+=plus;return;}
int mid=(l+r)/2;pushdown(p);
if(dl<mid){setadd(p<<1,l,mid,dl,min(dr,mid),plus);}
if(mid<dr){setadd(p<<1|1,mid,r,max(dl,mid),dr,plus);}
ma[p]=max(ma[p<<1],ma[p<<1|1]);
}
inline int gmax(int p,int l,int r,int dl,int dr)//查最大值
{
if(dl==l&&dr==r){return ma[p];}
int mid=(l+r)/2;pushdown(p);int res=-0x3f3f3f3f;
if(dl<mid){res=max(res,gmax(p<<1,l,mid,dl,min(dr,mid)));}
if(mid<dr){res=max(res,gmax(p<<1|1,mid,r,max(dl,mid),dr));}
return res;
}
inline void build(int p,int l,int r)//建树
{
if(r-l==1){ma[p]=dep[dfn[r]];return;}int mid=(l+r)/2;
build(p<<1,l,mid);build(p<<1|1,mid,r);ma[p]=max(ma[p<<1],ma[p<<1|1]);
}
}lt;
inline void add(int u,int V){v[++ct]=V;x[ct]=al[u];al[u]=ct;}//加边
inline void dfs(int u)//dfs处理倍增lca和dfs序
{
siz[u]=1;dfn[++df]=u;nfd[u]=df;//倍增
for(int i=0;f[u][i];i++){f[u][i+1]=f[f[u][i]][i];}
for(int i=al[u];i;i=x[i])
{
if(siz[v[i]]){continue;}
f[v[i]][0]=u;dep[v[i]]=dep[u]+1;dfs(v[i]);siz[u]+=siz[v[i]];
}
}
inline int lca(int u,int v)//查lca
{
if(dep[u]<dep[v]){swap(u,v);}int d=dep[u]-dep[v];
for(int i=0;d;d>>=1,i++){if(d&1){u=f[u][i];}}if(u==v){return v;}
for(int i=20;i>=0;i--){if(f[u][i]!=f[v][i]){u=f[u][i];v=f[v][i];}}
return f[u][0];
}
inline void modify(int u)//修改
{
if(h[u]){lt.setadd(1,0,n,nfd[h[u]]-1,nfd[h[u]]+siz[h[u]]-1,1);spt.split(u);h[u]=0;}
for(u=spt.getmi(u);u!=1;u=spt.mi(f[u][0]))//注意先分裂待修改点
{
int& d=f[u][0];//然后先断开重儿子再合并新的重儿子,同时在线段树上修改
if(h[d]){lt.setadd(1,0,n,nfd[h[d]]-1,nfd[h[d]]+siz[h[d]]-1,1);spt.split(d);}
lt.setadd(1,0,n,nfd[u]-1,nfd[u]+siz[u]-1,-1);spt.merge(u,d);h[d]=u;
}
}
int main()
{
scanf("%d%d",&n,&m);int tp;int u;int v;
for(int i=1;i<n;i++){scanf("%d%d",&u,&v);add(u,v);add(v,u);}
dfs(1);lt.build(1,0,n);//预处理
for(int i=1;i<=m;i++)
{
scanf("%d%d",&tp,&u);if(tp==1){modify(u);}//修改
else if(tp==2)//查找点值然后减去
{
scanf("%d",&v);int l=lca(u,v);
int du=lt.gmax(1,0,n,nfd[u]-1,nfd[u]);
int dv=lt.gmax(1,0,n,nfd[v]-1,nfd[v]);
int dl=lt.gmax(1,0,n,nfd[l]-1,nfd[l]);
printf("%d\n",du+dv-2*dl+1);
}else {printf("%d\n",lt.gmax(1,0,n,nfd[u]-1,nfd[u]+siz[u]-1)+1);}//区间最大值
}return 0;//拜拜程序~
}
```