NOI2021 d1t1

· · 题解

我们可以把对「边」的询问转到「点」上!

具体来说,对于一条边 (u,v),我们认为:

其中,\text{val}(x) 表示点 x 的权值,这个权值我们可以自己定义~

这样一来,修改操作就变成了:

那么查询操作实际上也就变成了:查询区间内有多少个相邻的点对,满足这两个点的权值相同。

因此我们先做一次树剖转为区间问题,然后线段树维护即可=w=

具体地,对于线段树上的每个节点,我们维护一下这个节点的左/右端的权值以及答案,合并时只需要把左右儿子的答案加起来,然后特判一下「左儿子的右端点」与「右儿子的左端点」是否相同即可。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<vector>
#include<cstdio>

#define lson(o) (o<<1)
#define rson(o) (o<<1|1)
const int MN=1e5+5;

using namespace std;

int w[MN],val[MN],dfn[MN],sz[MN],fa[MN],top[MN],hson[MN],dep[MN];
vector<int>G[MN];

inline int read(){
    int x=0,f=1;char c=getchar();
    for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
    for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
    return x*f;
}

int dfs1(int u,int d){
    dep[u]=d,sz[u]=1,hson[u]=0;
    for(int i=0,s=G[u].size();i<s;i++){
        int to=G[u][i];
        if(to==fa[u])continue;
        fa[to]=u;
        sz[u]+=dfs1(to,d+1);
        if(sz[to]>sz[hson[u]])hson[u]=to;
    }
    return sz[u];
}

int tot=0;
void dfs2(int u,int tp){
    dfn[u]=++tot,top[u]=tp,val[dfn[u]]=w[u];
    if(hson[u])dfs2(hson[u],tp);
    for(int i=0,s=G[u].size();i<s;i++){
        int to=G[u][i];
        if(to==fa[u]||to==hson[u])continue;
        dfs2(to,to);
    }
}

struct SMT{

    int d[MN<<2],L[MN<<2],R[MN<<2],lz[MN<<2];//d 为此区间答案,L/R 表示左右端点,lz是懒标记。

    inline void pushup(int o){
        d[o]=d[lson(o)]+d[rson(o)]+(L[rson(o)]==R[lson(o)]);
        L[o]=L[lson(o)],R[o]=R[rson(o)];
    }

    inline void build(int l,int r,int o){
        lz[o]=0;
        if(l==r){
            d[o]=0;
            L[o]=R[o]=val[l];
            return ;
        }
        int mid=(l+r)>>1;
        build(l,mid,lson(o));
        build(mid+1,r,rson(o));
        pushup(o);
    }

    inline void pushdown(int ql,int qr,int o){
        if(!lz[o])return ;
        int mid=(ql+qr)>>1;
        d[lson(o)]=mid-ql,d[rson(o)]=qr-mid-1;
        lz[lson(o)]=lz[o],lz[rson(o)]=lz[o];
        L[lson(o)]=R[lson(o)]=L[rson(o)]=R[rson(o)]=lz[o];
        lz[o]=0;
    }

    inline int query(int l,int r,int ql,int qr,int o){
        if(l<=ql&&qr<=r)return d[o];
        pushdown(ql,qr,o);
        int mid=(ql+qr)>>1,ans=0;
        if(l<=mid)ans+=query(l,r,ql,mid,lson(o));
        if(r>mid)ans+=query(l,r,mid+1,qr,rson(o));
        if(l<=mid&&r>mid&&R[lson(o)]==L[rson(o)])ans++;
        return ans;
    }

    inline void modify(int l,int r,int k,int ql,int qr,int o){
        if(l<=ql&&qr<=r){
            L[o]=R[o]=lz[o]=k;
            d[o]=qr-ql;
            return ;
        }
        pushdown(ql,qr,o);
        int mid=(ql+qr)>>1;
        if(l<=mid)modify(l,r,k,ql,mid,lson(o));
        if(r>mid)modify(l,r,k,mid+1,qr,rson(o));
        pushup(o);
    }

    inline int kwii(int pos,int ql,int qr,int o){
        if(ql==qr)return L[o];
        pushdown(ql,qr,o);
        int mid=(ql+qr)>>1;
        if(pos<=mid)return kwii(pos,ql,mid,lson(o));
        else return kwii(pos,mid+1,qr,rson(o));
    }

};

SMT tree;
int n,m;

int cnt=0;
void change(int x,int y){
    ++cnt;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        tree.modify(dfn[top[x]],dfn[x],cnt,1,n,1);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    tree.modify(dfn[x],dfn[y],cnt,1,n,1);
}

int queryans(int x,int y){
    int res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        res+=tree.query(dfn[top[x]],dfn[x],1,n,1);
        if(tree.kwii(dfn[fa[top[x]]],1,n,1)==tree.kwii(dfn[top[x]],1,n,1))res++;//这里需要特判一下链顶元素与其父亲是否相同
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    res+=tree.query(dfn[x],dfn[y],1,n,1);
    return res;
}

void solve(){
    cnt=tot=0;
    n=read(),m=read();
    for(int i=1;i<=n;i++)w[i]=++cnt;//由于初始均为轻边,所以我们需要保证 w[i]!=w[j],那么直接赋值成 1~n 就行了qwq

    for(int i=1;i<n;i++){
        int u=read(),v=read();
        G[u].push_back(v),G[v].push_back(u);
    }

    dfs1(1,1);
    dfs2(1,1);

    tree.build(1,n,1);

    while(m--){
        int op=read(),x=read(),y=read();
        if(op==1)change(x,y);
        if(op==2)cout<<queryans(x,y)<<endl;
    }

    for(int i=1;i<=n;i++)G[i].clear();
}

int _;

signed main(void){

    // freopen("edge.in","r",stdin);
    // freopen("edge.out","w",stdout);

    cin>>_;
    while(_--)solve();

    return 0;
}

那么就这样=w=