树链剖分 P2950树的统计(比模板还简单。。)

2018-05-20 20:36:29


哈哈哈哈哈哈哈,蒟蒻居然过了树剖,哈哈哈哈哈啊哈,嗝~

那么请开始我的树剖


首先上定义:

定义siz(x)为以x为根的子树的结点个数。

令v为u的儿子结点中siz()值最大的结点:

重结点:子树结点数目最多的结点

轻节点:父亲节点中除了重结点以外的结点

重边:父亲结点和重结点连成的边,即边(u,v)

轻边:重边之外的边

重链:由多条重边连接而成的路径

轻链:由多条轻边连接而成的路径

一个非叶节点有且仅有一个重儿子

加粗的链就是重链,注意4,7,8这三个点也是重链


那么再上一堆数组:

siz[i]以i为根结点的子树中结点的数目

hson[i]结点i的重儿子的编号

dep[i]结点i的深度,根的深度为1

top[i]结点i所在的重链的链首结点编号

fa[i]结点i的父结点编号

tid[i]结点i剖分以后的新编号(dfs2序号)

rnk[i]结点i在树中的原位置:i为新编号,rnk[i]为原编号


那么怎么求这些数组呢?

DFS1找重边 顺便求出siz[i],dep[i],fa[i],hson[i]

DFS2将重边连成重链 顺便求出top[i],tid[i],说的dfs2序号就是dfs序啦

橙色的是原序号,绿色的是dfs序,注意,我们先走重儿子,连成重链,让一个重链中的数dfs全部连在一起,这样就为我们后面的线段树维护做了铺垫

这样清楚一点


现在是线段树维护部分

我们已经把树处理成了线性,那么我们就可以用线段树维护了

如果要求i~j的最大值,如果i和j在同一个重链中,我们就直接输出max(tid[i]~tid[j]),而如果不在同一个重链中,我们就让深度大的那一个爬树,爬到它的top,求出其中的最大值,直到两个数在同一重链中,我们再求一下取个max就可以输出了

那为什么不可以直接线段树呢?当然是因为不同重链并不一定不连续

至于修改的话,正常线段树修改就行了


那么,代码时间

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,cnt=0;
int a[300100],f[300100];
int head[300100],nxt[300100],to[300100];
int siz[300100],d[300100],hs[300100];
int top[300100],ran[300100],tid[300100];
int sum[300100],ma[300100];
void addedge(int x,int y,int cn)
{
    nxt[cn]=head[x];
    head[x]=cn;
    to[cn]=y;
}
void dfs1(int u,int fa)
{
    f[u]=fa,siz[u]=1;//求出fa,size,dep,hson
    for(int i=head[u];i!=-1;i=nxt[i])
    {
        int v=to[i];
        if(v==fa) continue;
        d[v]=d[u]+1;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[hs[u]]) hs[u]=v;
    }
}
void dfs2(int u,int h)
{
    top[u]=h,tid[u]=++cnt,ran[cnt]=u;//求出top,tid,rank
    if(!hs[u]) return;//重边走到头return
    dfs2(hs[u],h);//先走重边
    for(int i=head[u];i!=-1;i=nxt[i])
    {
        int v=to[i];
        if(v==f[u] || v==hs[u]) continue;
        dfs2(v,v);
    }
}
void build(int l,int r,int x)//建树
{
    if(l==r)
    {
        sum[x]=ma[x]=a[ran[l]];
        return;
    }
    int mid=(l+r)/2;
    build(l,mid,x*2);
    build(mid+1,r,x*2+1);
    sum[x]=sum[x*2]+sum[x*2+1];
    ma[x]=max(ma[x*2],ma[x*2+1]);
}
void Change(int l,int r,int x,int u,int v)//线段树修改
{
    if(l==r)
    {
        ma[x]=sum[x]=v;
        return;
    }
    int mid=(l+r)/2;
    if(u<=mid) Change(l,mid,x*2,u,v);
    else Change(mid+1,r,x*2+1,u,v);
    sum[x]=sum[x*2]+sum[x*2+1];
    ma[x]=max(ma[x*2],ma[x*2+1]);
}
int Clacm(int l,int r,int x,int L,int R)
{
    if(L<=l && R>=r) return ma[x];
    int mid=(l+r)/2;
    if(R<=mid) return Clacm(l,mid,x*2,L,R);
    else if(L>mid) return Clacm(mid+1,r,x*2+1,L,R);
    else return max(Clacm(l,mid,x*2,L,R),Clacm(mid+1,r,x*2+1,L,R));
}
int Querym(int l,int r,int x,int u,int v)//max查询
{
    int ans=-1000000000;//注意权值有负数,赋个负无穷
    while(top[u]!=top[v])//不在同一重链中,爬树
    {
        if(d[top[u]]<d[top[v]]) swap(u,v);
        ans=max(ans,Clacm(1,n,1,tid[top[u]],tid[u]));//深度大的那个爬到top,同时取u~top[u]的最大值max一下
        u=f[top[u]];
    }
    if(d[u]<d[v]) swap(u,v);
    ans=max(ans,Clacm(1,n,1,tid[v],tid[u]));//同一重链直接线段树求max
    return ans;
}
int Clacs(int l,int r,int x,int L,int R)
{
    if(L<=l && R>=r) return sum[x];
    int mid=(l+r)/2;
    if(R<=mid) return Clacs(l,mid,x*2,L,R);
    else if(L>mid) return Clacs(mid+1,r,x*2+1,L,R);
    else return Clacs(l,mid,x*2,L,R)+Clacs(mid+1,r,x*2+1,L,R);
}
int Querys(int l,int r,int x,int u,int v)//sum同理
{
    int ans=0;
    while(top[u]!=top[v])
    {
        if(d[top[u]]<d[top[v]]) swap(u,v);
        ans+=Clacs(1,n,1,tid[top[u]],tid[u]);
        u=f[top[u]];
    }
    if(d[u]<d[v]) swap(u,v);
    ans+=Clacs(1,n,1,tid[v],tid[u]);
    return ans;
}
int main()
{
    memset(hs,0,sizeof(hs));
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        addedge(x,y,i*2-1);
        addedge(y,x,i*2);
    }
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    d[1]=1,dfs1(1,1),dfs2(1,1);
    build(1,n,1);
    char tmp[10];
    int m,x,y;
    scanf("%d",&m);
    while(m--)
    {
        scanf("%s%d%d",tmp,&x,&y);
        if(tmp[0]=='C') Change(1,n,1,tid[x],y);
        else if(tmp[1]=='M') printf("%d\n",Querym(1,n,1,x,y));
        else printf("%d\n",Querys(1,n,1,x,y));
    }
    return 0;
}