P9555 「CROI · R1」浣熊的阴阳鱼

· · 题解

cnblogs

树链剖分真可爱。

题目链接

  • 给出 n 个点的树,点有点权 a_i(a_i\in\{0,1\})。支持 q 次操作:

看到树上修改和路径查询,首先想到树剖。我们发现 |S|\le 2,且 S 内元素的顺序不影响答案,因此我们可以用一个二元组 (i,j)(0\le i\le j\le 2) 记录 S 的状态(0,1 表示放了什么元素,2 表示没有元素)。为了方便表述,将 2 也认为是 S 内的元素,即强制 |S|=2

分析一下询问,相当于已经给定了初始状态 S=\{a_u,2\}。再根据树剖的思想,我们要快速查询一条重链上的信息,即快速查询那条链对应的区间的信息。考虑使用线段树维护。具体地,对于线段树上的一个节点 x(设对应的区间为 [l,r]),记 f_{x,i,j} 表示从 l 位置开始走,经过 l 位置后 S=\{i,j\},走到 r 时的得分,类似地 g_{x,i,j} 表示从 r 走到 l 的得分。还需要记录 LtoR_{x,i,j} 表示从 l 位置开始走,经过 l 位置后 S=\{i,j\},到达 rS 的状态,同理还有 RtoL_{x,i,j} 表示从 r 到达 lS 的状态。注意,我们强调了初始 S 是经过起点后的状态,意味着上述的 S已经受起点的影响,即统计答案的时候不再受到起点的影响(因为已经受过了)。类似地,强调 LtoRRtoL 是到达端点后的状态,说明信息已经受到终点的影响,因为状态的定义里保证了它受到起点的影响,所以同时保证了统计了完整的信息。

考虑如何合并区间信息,我们发现需要快速计算出已经到了一个区间末尾,下一步走到另一半区间开头时,S 的状态,因此还需要记录 lc_{x},rc_{x} 来存储 a_la_r。所以线段树的节点是张这样的:

#define pii pair<int,int>
struct node{//变量名有部分不同。
    int cnt_l[3][3],cnt_r[3][3],lc,rc;//f,g,lc,rc。
    pii status_l[3][3],status_r[3][3];//LtoR,RtoL。
}seg[N<<2];

然后计算状态可以通过以下的函数实现(我写的比较暴力,直接枚举 12 种情况分类讨论):

#define ppi pair<pii,int>
#define mp make_pair
ppi get_status(pii p,int x){//第二维表示得分增量。
    if(p==mp(0,0)&&!x){
        return mp(p,0);
    }
    if(p==mp(0,0)&&x){
        return mp(mp(0,2),1);
    }
    if(p==mp(0,1)&&!x){
        return mp(mp(0,2),1);
    }
    if(p==mp(0,1)&&x){
        return mp(mp(1,2),1);
    }
    if(p==mp(0,2)&&!x){
        return mp(mp(0,0),0);
    }
    if(p==mp(0,2)&&x){
        return mp(mp(2,2),1);
    }
    if(p==mp(1,1)&&!x){
        return mp(mp(1,2),1);
    }
    if(p==mp(1,1)&&x){
        return mp(mp(1,1),0);
    }
    if(p==mp(1,2)&&!x){
        return mp(mp(2,2),1);
    }
    if(p==mp(1,2)&&x){
        return mp(mp(1,1),0);
    }
    return mp(mp(x,2),0);//空集就直接放。
}

区间信息具体合并方法为,先从计算当前状态走到区间末尾的信息,然后计算跨过区间后的状态,以及计算跨区间这一步对得分的贡献。然后再从另一半区间,以跨区间后的状态开始走,计算得分。从当前起点走到另一个端点的状态,就是从另一半区间,以跨区间后的状态开始走,走到那个端点时的状态。代码如下:

#define fi first 
#define se second
node merge(node l,node r){
    node ret;
    ret.lc=l.lc;
    ret.rc=r.rc;
    for(int i=0;i<=2;++i){
        for(int j=i;j<=2;++j){
            ppi l_start=get_status(r.status_r[i][j],l.rc),r_start=get_status(l.status_l[i][j],r.lc);
            ret.cnt_l[i][j]=l.cnt_l[i][j]+r_start.se+r.cnt_l[r_start.fi.fi][r_start.fi.se];
            ret.status_l[i][j]=r.status_l[r_start.fi.fi][r_start.fi.se];
            ret.cnt_r[i][j]=r.cnt_r[i][j]+l_start.se+l.cnt_r[l_start.fi.fi][l_start.fi.se];
            ret.status_r[i][j]=l.status_r[l_start.fi.fi][l_start.fi.se];
        }
    }
    return ret;
}

对于长度为 1 的区间,有初始化:

seg[x].lc=seg[x].rc=b[l];//b[l] 是那个点的元素值。
for(int i=0;i<=2;++i){
    for(int j=i;j<=2;++j){
        seg[x].cnt_l[i][j]=seg[x].cnt_r[i][j]=0;//端点已经考虑过且没有遇到新元素,对得分无贡献。
        seg[x].status_l[i][j]=seg[x].status_r[i][j]=mp(i,j);//起点终点相同,受到起点的影响即受到了终点的影响。
    }
}

那么单点修改也很好维护:

seg[x].lc^=1;
seg[x].rc^=1;

查询就是跳链查询。注意合并信息的顺序,具体可以参考 GSS7。

时间复杂度为 \mathcal{O}(q\log ^2n),空间复杂度为 \mathcal{O}(n)

评测记录

代码

#include<bits/stdc++.h>
#define ls(x) ((x)<<1)
#define rs(x) ((x)<<1|1)
#define fi first 
#define se second 
#define pii pair<int,int> 
#define ppi pair<pii,int>
#define mp make_pair
using namespace std;
const int N=1e5+5;
int n,q,a[N],siz[N],dep[N],top[N],hson[N],fa[N],dfn[N],id,b[N];
vector<int>g[N];
struct node{
    int cnt_l[3][3],cnt_r[3][3],lc,rc;//l to r;r to l;
    pii status_l[3][3],status_r[3][3];//start l;start r;
}seg[N<<2];
void dfs1(int u){
    siz[u]=1;
    for(int v:g[u]){
        if(v!=fa[u]){
            dep[v]=dep[u]+1;
            fa[v]=u;
            dfs1(v);
            siz[u]+=siz[v];
        }
    }
}
void dfs2(int u){
    for(int v:g[u]){
        if(v!=fa[u]){
            if((siz[v]<<1)>siz[u]){
                hson[u]=v;
                top[v]=top[u];
            }else{
                top[v]=v;
            }
            dfs2(v);
        }
    }
}
void dfs3(int u){
    dfn[u]=++id;
    b[id]=a[u];
    if(hson[u]){
        dfs3(hson[u]);
    }
    for(int v:g[u]){
        if(v!=fa[u]&&v!=hson[u]){
            dfs3(v);
        }
    }
}
ppi get_status(pii p,int x){
    if(p==mp(0,0)&&!x){
        return mp(p,0);
    }
    if(p==mp(0,0)&&x){
        return mp(mp(0,2),1);
    }
    if(p==mp(0,1)&&!x){
        return mp(mp(0,2),1);
    }
    if(p==mp(0,1)&&x){
        return mp(mp(1,2),1);
    }
    if(p==mp(0,2)&&!x){
        return mp(mp(0,0),0);
    }
    if(p==mp(0,2)&&x){
        return mp(mp(2,2),1);
    }
    if(p==mp(1,1)&&!x){
        return mp(mp(1,2),1);
    }
    if(p==mp(1,1)&&x){
        return mp(mp(1,1),0);
    }
    if(p==mp(1,2)&&!x){
        return mp(mp(2,2),1);
    }
    if(p==mp(1,2)&&x){
        return mp(mp(1,1),0);
    }
    return mp(mp(x,2),0);
}
node merge(node l,node r){
    node ret;
    ret.lc=l.lc;
    ret.rc=r.rc;
    for(int i=0;i<=2;++i){
        for(int j=i;j<=2;++j){
            ppi l_start=get_status(r.status_r[i][j],l.rc),r_start=get_status(l.status_l[i][j],r.lc);
            ret.cnt_l[i][j]=l.cnt_l[i][j]+r_start.se+r.cnt_l[r_start.fi.fi][r_start.fi.se];
            ret.status_l[i][j]=r.status_l[r_start.fi.fi][r_start.fi.se];
            ret.cnt_r[i][j]=r.cnt_r[i][j]+l_start.se+l.cnt_r[l_start.fi.fi][l_start.fi.se];
            ret.status_r[i][j]=l.status_r[l_start.fi.fi][l_start.fi.se];
        }
    }
    return ret;
}
void build(int x,int l,int r){
    if(l==r){
        seg[x].lc=seg[x].rc=b[l];
        for(int i=0;i<=2;++i){
            for(int j=i;j<=2;++j){
                seg[x].cnt_l[i][j]=seg[x].cnt_r[i][j]=0;
                seg[x].status_l[i][j]=seg[x].status_r[i][j]=mp(i,j);
            }
        }
        return;
    }
    int mid=(l+r)>>1;
    build(ls(x),l,mid);
    build(rs(x),mid+1,r);
    seg[x]=merge(seg[ls(x)],seg[rs(x)]);
}
void modify(int x,int l,int r,int k){
    if(l==r){
        seg[x].lc^=1;
        seg[x].rc^=1;
        return;
    }
    int mid=(l+r)>>1;
    if(k<=mid){
        modify(ls(x),l,mid,k);
    }else{
        modify(rs(x),mid+1,r,k);
    }
    seg[x]=merge(seg[ls(x)],seg[rs(x)]);
}
node query(int x,int l,int r,int ql,int qr){
    if(ql<=l&&r<=qr){
        return seg[x];
    }
    int mid=(l+r)>>1;
    if(qr<=mid){
        return query(ls(x),l,mid,ql,qr);
    }
    if(ql>mid){
        return query(rs(x),mid+1,r,ql,qr);
    }
    return merge(query(ls(x),l,mid,ql,qr),query(rs(x),mid+1,r,ql,qr));
}
node pathquery(int x,int y){
    node info_x,info_y,temp;
    bool empty_x=1,empty_y=1;
    while(top[x]!=top[y]){
        if(dep[top[x]]>dep[top[y]]){
            temp=query(1,1,n,dfn[top[x]],dfn[x]);
            if(empty_x){
                info_x=temp;
                empty_x=0;
            }else{
                info_x=merge(temp,info_x);
            }
            x=fa[top[x]];
        }else{
            temp=query(1,1,n,dfn[top[y]],dfn[y]);
            if(empty_y){
                info_y=temp;
                empty_y=0;
            }else{
                info_y=merge(temp,info_y);
            }
            y=fa[top[y]];
        }
    }
    if(dep[x]<dep[y]){
        temp=query(1,1,n,dfn[x],dfn[y]);
        if(!empty_x){
            swap(info_x.cnt_l,info_x.cnt_r);
            swap(info_x.status_l,info_x.status_r);
            if(!empty_y){
                return merge(info_x,merge(temp,info_y));
            }
            return merge(info_x,temp);
        }
        if(!empty_y){
            return merge(temp,info_y);
        }
        return temp;
    }else{
        temp=query(1,1,n,dfn[y],dfn[x]);
        if(!empty_x){
            info_x=merge(temp,info_x);
            swap(info_x.cnt_l,info_x.cnt_r);
            swap(info_x.status_l,info_x.status_r);
            if(!empty_y){
                return merge(info_x,info_y);
            }
            return info_x;
        }
        swap(temp.cnt_l,temp.cnt_r);
        swap(temp.status_l,temp.status_r);
        if(!empty_y){
            return merge(temp,info_y);
        }
        return temp;
    }
}
signed main(){
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
    cin>>n>>q;
    for(int i=1;i<=n;++i){
        cin>>a[i];
    }
    for(int i=1,u,v;i<n;++i){
        cin>>u>>v;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    dfs1(1);
    top[1]=1;
    dfs2(1);
    dfs3(1);
    build(1,1,n);
    for(int op,u,v,i=1;i<=q;++i){
        cin>>op>>u;
        if(op==1){
            a[u]^=1;
            modify(1,1,n,dfn[u]);
        }else{
            cin>>v;
            node ans=pathquery(u,v);
            cout<<ans.cnt_l[a[u]][2]<<'\n';
        }
    }
    return 0;
}