题解 P4219 【[BJOI2014]大融合】

· · 题解

题目大意

给定 n 个结点和 q 次操作,每个操作为如下形式:

A x y 在结点 xy 之间连接一条边。

Q x y 给定一条已经存在的边 (x,y) ,求有多少条简单路径,其中包含边 (x,y)

保证在任意时刻,图的形态都是一棵森林。

1\le n,q,x,y\le 10^5

解题思路

为询问 Q 考虑另一种表述,我们发现答案等于边 (x,y)x 侧的结点数与 y 侧的结点数的乘积,即将边 (x,y) 断开后分别包含 xy 的树的结点数。为了消除断边的影响,在询问后我们再次连接边 (x,y)

题目中的操作既有连边,又有删边,还保证在任意时刻都是一棵森林,我们不由得想到用 LCT 来维护。但是这题中 LCT 维护的是子树的大小,不像我们印象中的维护一条链的信息,而 LCT 的构造 认父不认子 ,不方便我们直接进行子树的统计。怎么办呢?

方法是统计一个结点 x 所有虚儿子(即父亲为 x ,但 x 在 splay 中的左右儿子并不包含 x )所代表的子树的贡献。

定义 siz2[x] 为结点 x 的所有虚儿子代表的子树的结点数, siz[x] 为 结点 x 子树中的结点数。

不同于以往我们维护 Splay 中子树结点个数的方法,我们在计算结点 x 子树中的结点数时,还要加上 siz2[x] ,即

void maintain(int x){clear(0);if(x)siz[x]=siz[ch[x][0]]+1+siz[ch[x][1]]+siz2[x];}

而且在我们 改变 Splay 的形态 (即改变一个结点在 Splay 上的左右儿子指向时),需要及时修改 siz2[x] 的值。

rotate,splay 操作中,我们都只是改变了 Splay 中结点的相对位置,没有改变任意一条边的虚实情况,所以不对 siz2[x] 进行任何修改。

access 操作中,在每次 splay 完后,都会改变刚刚 splay 完的结点的右儿子,即该结点与其原右儿子的连边和该节点和新右儿子的连边的虚实情况发生了变化,我们需要加上新变成虚边所连的子树的贡献,减去刚刚变成实边所连的子树的贡献。代码如下:

void access(int x)
{
    for(int f=0;x;f=x,x=fa[x])
    splay(x),siz2[x]+=siz[ch[x][1]]-siz[f],ch[x][1]=f,maintain(x);
}

makeroot,find 操作中,我们都只是调用了之前的函数或者在 Splay 上条边,并不用做任何修改。

在连接两点时,我们修改了一个结点的父亲。我们需要在父亲结点的 siz2 值中加上新子结点的子树大小贡献。

st.makeroot(x);
st.makeroot(y);
st.fa[x]=y;
st.siz2[y]+=st.siz[x];

在断开一条边时,我们只是删除了 Splay 上的一条实边, maintain 操作会维护这些信息,不需要做任何修改。

代码修改的细节讲完了,现在我们总结一下 LCT 维护子树信息的要求和方法:

  1. 维护的信息要有 可减性 ,如子树结点数,子树权值和,但不能直接维护子树最大最小值,因为在将一条虚边变成实边时要排除原先虚边的贡献。

  2. 新建一个附加值存储虚子树的贡献,在统计时将其加入本结点答案,在改变边的虚实时及时维护。

  3. 其余部分同普通 LCT ,在统计子树信息时一定将其作为根节点。

  4. 如果维护的信息没有可减性,如维护区间最值,可以对每个结点开一个平衡树维护结点的虚子树中的最值。

代码展示

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=100010;
typedef long long ll;
struct Splay
{
    int ch[maxn][2],fa[maxn],siz[maxn],siz2[maxn],tag[maxn];
    void clear(int x){ch[x][0]=ch[x][1]=fa[x]=siz[x]=siz2[x]=tag[x]=0;}
    int getch(int x){return ch[fa[x]][1]==x;}
    int isroot(int x){return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}
    void maintain(int x){clear(0);if(x)siz[x]=siz[ch[x][0]]+1+siz[ch[x][1]]+siz2[x];}
    void pushdown(int x)
    {
        if(tag[x])
        {
            if(ch[x][0])swap(ch[ch[x][0]][0],ch[ch[x][0]][1]),tag[ch[x][0]]^=1;
            if(ch[x][1])swap(ch[ch[x][1]][0],ch[ch[x][1]][1]),tag[ch[x][1]]^=1;
            tag[x]=0;
        }
    }
    void update(int x)
    {
        if(!isroot(x))update(fa[x]);
        pushdown(x);
    }
    void rotate(int x)
    {
        int y=fa[x],z=fa[y],chx=getch(x),chy=getch(y);
        fa[x]=z;if(!isroot(y))ch[z][chy]=x;
        ch[y][chx]=ch[x][chx^1];fa[ch[x][chx^1]]=y;
        ch[x][chx^1]=y;fa[y]=x;
        maintain(y);maintain(x);maintain(z);
    }
    void splay(int x)
    {
        update(x);
        for(int f=fa[x];f=fa[x],!isroot(x);rotate(x))
        if(!isroot(f))rotate(getch(x)==getch(f)?f:x);
    }
    void access(int x)
    {
        for(int f=0;x;f=x,x=fa[x])
        splay(x),siz2[x]+=siz[ch[x][1]]-siz[f],ch[x][1]=f,maintain(x);
    }
    void makeroot(int x)
    {
        access(x);splay(x);
        swap(ch[x][0],ch[x][1]);
        tag[x]^=1;
    }
    int find(int x)
    {
        access(x);splay(x);
        while(ch[x][0])x=ch[x][0];
        splay(x);
        return x;
    }
}st;
int n,q,x,y;
char op;
int main()
{
    scanf("%d%d",&n,&q);
    while(q--)
    {
        scanf(" %c%d%d",&op,&x,&y);
        if(op=='A')
        {
            st.makeroot(x);
            st.makeroot(y);
            st.fa[x]=y;
            st.siz2[y]+=st.siz[x];
        }
        if(op=='Q')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            st.ch[y][0]=st.fa[x]=0;
            st.maintain(x);
            st.makeroot(x);st.makeroot(y);
            printf("%lld\n",(ll)(st.siz[x]*st.siz[y]));
            st.makeroot(x);
            st.makeroot(y);
            st.fa[x]=y;
            st.siz2[y]+=st.siz[x];
        }
        //for(int i=1;i<=n;i++)printf("%d %d %d %d %d %d\n",i,st.ch[i][0],st.ch[i][1],st.fa[i],st.siz[i],st.siz2[i]);
    }
    return 0;
}