题解 P5556 【圣剑护符】

· · 题解

树链剖分·线性基·真·板子题

SJY说:这是一道【哔~】题

首先,我们看到了异或操作,那毫无疑问,就是用线性基了

我们知道,线性基有如下性质:

那么很显然,如果一个数无法被插入,则说明这个数可以\oplus线性基中的数从而变成0,换句或说:

如果一个数无法被插入,那么它一定可以由已插入的数异或得到

那么我们就可以用这个性质来判断一组数据是否存在两个不相等的子集,使得两个子集的值相同

然后,我们又发现一个很有趣的性质:线性基中一共有\lceil log_2(v)+1\rceil个位置,所以最多只能插入31个数,对于后面插入的数,一定可以由前面的数异或得到,于是我们又得到了一条有用的性质:

如果两点间简单路径上的点数>31,那么一定有集合的值相等

那么小蒟蒻又有问题了:按照上面的思路,就需要快速求出树上两点的距离,还要快速修改,蒟蒻做不到啊

于是这时,树链剖分就出现了。不仅能快速求LCA,套上线段树快速修改,完全符合我们的需要

借鉴完树链剖分,这道题就愉快地做完了

#pragma optimize(2)
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void read(T &x){
    x=0;char c=getchar();bool f=false;
    for(;!isdigit(c);c=getchar())f!=c=='-';
    for(;isdigit(c);c=getchar())x=x*10+c-'0';
    if(f)x=-x;
}
template<typename T ,typename ...Arg>
inline void read(T &x,Arg &...args){
    read(x);read(args...);
}
template<typename T>
inline void write(T x){
    if(x<0)putchar('-'),x=-x;
    if(x>=10)write(x/10);
    putchar(x%10+'0');
}
char get(){char c=getchar();while(!isalpha(c))c=getchar();return c;}
const int maxn=1e5+100;
struct Graph{
    struct node{int v,nxt;}e[maxn*2];
    int cnt,head[maxn];
    inline void add(int x,int y){e[++cnt]={y,head[x]};head[x]=cnt;}
    #define For(G,x) for(int ___=(G).head[x];___;___=(G).e[___].nxt)
    #define v(G) (G).e[___].v
}G;
//----------树----------
struct Segment_Tree{
    struct node{
        int num,l,r;
        int tag;
    }t[maxn<<2];
    void build(int x,int l,int r){
        t[x].l=l,t[x].r=r;
        t[x].num=0,t[x].tag=0;
        if(l==r)return;
        int mid=l+r>>1;
        build(x<<1,l,mid);
        build(x<<1|1,mid+1,r);
    }
    void pushdown(int x){
        if(t[x].tag){
            t[x<<1].tag^=t[x].tag;
            t[x<<1].num^=t[x].tag;
            t[x<<1|1].tag^=t[x].tag;
            t[x<<1|1].num^=t[x].tag;
            t[x].tag=0;
        }
    }
    void update(int x,int l,int r,int val){
        if(t[x].l>r||t[x].r<l)return;
        if(l<=t[x].l&&t[x].r<=r){t[x].num^=val;t[x].tag^=val;return;}
        pushdown(x);update(x<<1,l,r,val);update(x<<1|1,l,r,val);
    }
    int query(int x,int pos){
        if(t[x].l==t[x].r)return t[x].num;
        pushdown(x);
        int mid=t[x].l+t[x].r>>1;
        if(pos<=mid)return query(x<<1,pos);
        else return query(x<<1|1,pos);
    }
    void print(int x){
        if(t[x].l==t[x].r)printf("%d ",t[x].num);
        else pushdown(x),print(x<<1),print(x<<1|1);
    }
}ST;
//----------线段树----------
int fa[maxn];//父节点
int deep[maxn];//深度
int size[maxn];//子树大小
int top[maxn];//链顶
int id[maxn];//dfs序
int v[maxn];//权值 
int cnt=0;
void dfs1(int x,int f){
    //处理: fa deep size
    fa[x]=f,deep[x]=deep[f]+1;size[x]=1;
    For(G,x)if(v(G)!=f)dfs1(v(G),x),size[x]+=size[v(G)];
} 
void dfs2(int x,int f){
    //处理: top id
    id[x]=++cnt;top[x]=f;ST.update(1,cnt,cnt,v[x]);
    int MAX=-1,s=-1;
    For(G,x) if(v(G)!=fa[x]&&size[v(G)]>MAX) MAX=size[v(G)], s=v(G);
    if(MAX!=-1)dfs2(s,f);
    For(G,x) if(v(G)!=fa[x]&&v(G)!=s)dfs2(v(G),v(G));
}
void update(int x,int y,int z){
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]])
            swap(x,y);
        ST.update(1,id[top[x]],id[x],z);
        x=fa[top[x]];
    }
    if(id[x]>id[y])
        swap(x,y);
    ST.update(1,id[x],id[y],z);
}
int LCA(int u,int v){  
    while(top[u]!=top[v])/*u、v不在同一条重链上时*/{
        if(deep[top[u]]>deep[top[v]])//将深度大的上提
            u=fa[top[u]];
        else
            v=fa[top[v]];
    }
    if(deep[u]<deep[v])//返回u、v中在较上方的那个
        return u;
    return v;
}
//----------树链剖分----------
const int max_wei=31;
struct leaner_basis{
    int b[max_wei];
    void init(){memset(b,0,sizeof b);}
    bool insert(int x){
        for(int i=max_wei-1;i>=0;i--){
            if(!(x&(1<<i)))continue;
            if(!b[i]){b[i]=x;return true;}
            x^=b[i];
        }
        return false;
    }
}B;
//----------线性基---------- 
int n,q,x,y,z;
signed main(){
    read(n,q);
    for(int i=1;i<=n;i++)
        read(v[i]);
    for(int i=1;i<n;i++)
        read(x,y),G.add(x,y),G.add(y,x);
    ST.build(1,1,n);
    dfs1(1,-1);dfs2(1,1);
    for(int i=1;i<=q;i++){
        char opt=get();
        if(opt=='Q'){
            read(x,y);
            int lca=LCA(x,y);
            int dis=deep[x]+deep[y]-deep[lca]*2+1;
            if(dis>30)printf("YES\n");
            else{
                B.init();
                bool flag=false;
                if(!B.insert(ST.query(1,id[lca])))flag=true;
                if(!flag)
                    while(x!=lca){
                        if(!B.insert(ST.query(1,id[x]))){flag=true;break;}
                        x=fa[x];
                    }
                if(!flag)
                    while(y!=lca){
                        if(!B.insert(ST.query(1,id[y]))){flag=true;break;}
                        y=fa[y];
                    }
                printf(flag?"YES\n":"NO\n");
            }
        }
        else{
            read(x,y,z);
            update(x,y,z);
        }
    }
}