题解 P5352 【Terrible Homework】

· · 题解

这题真是糟糕啊

重构一遍就过了

解题思路

观察到题目维护的是动态树上的树链信息问题,不难想到用 LCT 维护。

考虑维护修改操作。由于按位异或操作对按位与,按位或和求和操作不具有分配率,所以我们讨论每一位的情况。

不难发现,异或二进制拆位后具有分配率。如果异或的 v 值在二进制上的第 i 位有值的话,会将这棵子树里的所有值的这一位上都会反转(即 0 变为 11 变为 0 )。

我们以辅助树子树中每一位上 01 的数量来记录树链的信息并处理查询。

进行修改操作时,提取出 xy 的树链到一棵树上,将根节点上对应位的 01 的数量交换,然后打上标记。在进行 splay 操作前,先下传所有路径上会经过的标记。

查询树链信息时,我们同样提取出 xy 之间的树链。

对于每一位,如果子树中这一位上没有 0 ,说明这一位上 and 值为 1

如果子树中这一位上至少有一个 1 ,说明这一位上 or 值为 1

如果子树中这一位上 1 的个数为奇数,说明这一位上 xor 的值为 1

求树链和时,只需求出所有位上 1 的个数乘这个数位的大小的总和。这里需要使用 long long

时间复杂度 O(n\log _2 n\log_2 A) ,空间复杂度 O(n\log_2 A)

参考代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=100010;
int n,q,x,y,z;
char op;
typedef long long ll;
ll ans;
struct Splay
{
    int ch[maxn][2],fa[maxn],v[maxn],tag[maxn],xr[maxn],zero[maxn][31],ones[maxn][31];
    void clear(int x)
    {
        ch[x][0]=ch[x][1]=fa[x]=tag[x]=xr[x]=0;
        for(int i=0;i<=30;i++)zero[x][i]=ones[x][i]=0;
    }
    int getch(int x){return ch[fa[x]][1]==x;}
    int isroot(int x){return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}
    void maintain(int x)
    {
        for(int i=0;i<=30;i++)zero[x][i]=ones[x][i]=0;
        for(int i=0;i<=30;i++)if(v[x]&(1<<i))ones[x][i]++;else zero[x][i]++;
        if(ch[x][0])for(int i=0;i<=30;i++)zero[x][i]+=zero[ch[x][0]][i],ones[x][i]+=ones[ch[x][0]][i];
        if(ch[x][1])for(int i=0;i<=30;i++)zero[x][i]+=zero[ch[x][1]][i],ones[x][i]+=ones[ch[x][1]][i];
    }
    void pushdown(int x)
    {
        if(tag[x])
        {
            if(ch[x][0])swap(ch[ch[x][0]][0],ch[ch[x][0]][1]),tag[ch[x][0]]^=1;
            if(ch[x][1])swap(ch[ch[x][1]][0],ch[ch[x][1]][1]),tag[ch[x][1]]^=1;
            tag[x]=0;
        }
        if(xr[x])
        {
            if(ch[x][0])
            {
                for(int i=0;i<=30;i++)if(xr[x]&(1<<i))swap(zero[ch[x][0]][i],ones[ch[x][0]][i]);
                v[ch[x][0]]^=xr[x];xr[ch[x][0]]^=xr[x];
            }
            if(ch[x][1])
            {
                for(int i=0;i<=30;i++)if(xr[x]&(1<<i))swap(zero[ch[x][1]][i],ones[ch[x][1]][i]);
                v[ch[x][1]]^=xr[x];xr[ch[x][1]]^=xr[x];
            }
            xr[x]=0;
        }
    }
    void update(int x)
    {
        if(!isroot(x))update(fa[x]);
        pushdown(x);
    }
    void rotate(int x)
    {
        int y=fa[x],z=fa[y],chx=getch(x),chy=getch(y);
        fa[x]=z;if(!isroot(y))ch[z][chy]=x;
        ch[y][chx]=ch[x][chx^1];fa[ch[x][chx^1]]=y;
        ch[x][chx^1]=y;fa[y]=x;
        maintain(y);maintain(x);
    }
    void splay(int x)
    {
        update(x);
        for(int f=fa[x];f=fa[x],!isroot(x);rotate(x))
        if(!isroot(f))rotate(getch(x)==getch(f)?f:x);
    }
    void access(int x)
    {
        for(int f=0;x;f=x,x=fa[x])
        splay(x),ch[x][1]=f,maintain(x);
    }
    void makeroot(int x)
    {
        access(x);splay(x);
        tag[x]^=1;
        swap(ch[x][0],ch[x][1]);
    }
}st;
int main()
{
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++)scanf("%d",&st.v[i]),st.maintain(i);
    while(q--)
    {
        scanf(" %c%d%d",&op,&x,&y);
        if(op=='L')
        {
            st.makeroot(x);st.fa[x]=y;
        }
        if(op=='C')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            st.fa[x]=st.ch[y][0]=0;
            st.maintain(y);
        }
        if(op=='U')
        {
            scanf("%d",&z);
            st.makeroot(x);st.access(y);st.splay(y);
            for(int i=0;i<=30;i++)if(z&(1<<i))swap(st.zero[y][i],st.ones[y][i]);
            st.xr[y]^=z;st.v[y]^=z;
        }
        if(op=='A')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            ans=0;
            for(int i=0;i<=30;i++)if(!st.zero[y][i])ans+=(1<<i);
            printf("%lld\n",ans);
        }
        if(op=='O')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            ans=0;
            for(int i=0;i<=30;i++)if(st.ones[y][i])ans+=(1<<i);
            printf("%lld\n",ans); 
        }
        if(op=='X')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            ans=0;
            for(int i=0;i<=30;i++)if(st.ones[y][i]&1)ans+=(1<<i);
            printf("%lld\n",ans);
        }
        if(op=='S')
        {
            st.makeroot(x);st.access(y);st.splay(y);
            ans=0;
            for(int i=0;i<=30;i++)ans+=(ll)(1<<i)*(ll)st.ones[y][i];
            printf("%lld\n",ans);
        }
    }
    return 0;
}