P9631 [ICPC2020 Nanjing R] Just Another Game of Stones 题解

· · 题解

题目大意

给定 n 个石堆,第 i 个石堆的石子数为 a_i,有 q 次操作:

1 l r x 表示令所有 a_i=\max(a_i,x),i\in[l,r]

2 l r x 表示用石堆 [l,r] 和一个石子数为 x 的石堆进行 Nim 游戏,求出第一次先手取完石子后游戏变为后手必败局面的可操作总数。

(如果你还不了解 Nim 游戏,点这里)。

思路

我们知道,Nim 游戏中所有石堆的石子数异或和为 0 的局面为先手必败局面,所以第一步需要保证取完石子后剩下石堆的石子数异或和为 0(相当于把先手必败的局面给对手)才能赢得游戏。

设当前异或和为 s,某一石堆里的石子数为 a_i,我们需要取走 a_i-a_i\oplus s 个石子才能保证取完后异或和为 0(也就是令 a_i=a_i\oplus s),那么当 a_i\ge a_i\oplus s 时就可以对该石堆进行一次操作,因此操作二所求答案就是 \sum_{i=l}^r[a_i\ge a_i\oplus s]。同时我们知道异或本质上是不进位的加法,考虑 s 的最高位 1,若 a_i 在此位也是 1,则 a_i\ge a_i\oplus s 恒成立,反之恒不成立。

可以用吉司机线段树(Segment Tree Beats)维护操作一(取 \max),同时对于线段树的每个节点都维护区间的异或和,并且记录每一种二进制位为1的数字的数量,向上合并时直接暴力即可。

(如果你还不了解吉司机线段树,点这里)。

AC 代码

#include<iostream>
#define mid ((l+r)>>1)
#define ls (rt<<1)
#define rs (rt<<1|1)
using namespace std;
const int N=200005;
struct node
{
    int cnt[31],ccnt,p,op,sum,tag;
}tr[N<<2];
int a[N],n,m;
struct SegmentTreeBeats
{
    void pushup(int rt)
    {
        if(tr[ls].p<tr[rs].p)
        {
            tr[rt].p=tr[ls].p;
            tr[rt].ccnt=tr[ls].ccnt;
            tr[rt].op=min(tr[ls].op,tr[rs].p);
            tr[rt].sum=tr[ls].sum^tr[rs].sum;
            for(int i=0;i<=30;i++)
                tr[rt].cnt[i]=tr[ls].cnt[i]+tr[rs].cnt[i];
        }
        else if(tr[ls].p>tr[rs].p)
        {
            tr[rt].p=tr[rs].p;
            tr[rt].ccnt=tr[rs].ccnt;
            tr[rt].op=min(tr[rs].op,tr[ls].p);
            tr[rt].sum=tr[rs].sum^tr[ls].sum;
            for(int i=0;i<=30;i++)
                tr[rt].cnt[i]=tr[ls].cnt[i]+tr[rs].cnt[i];
        }
        else
        {
            tr[rt].p=tr[ls].p;
            tr[rt].ccnt=tr[ls].ccnt+tr[rs].ccnt;
            tr[rt].op=min(tr[ls].op,tr[rs].op);
            tr[rt].sum=tr[ls].sum^tr[rs].sum;
            for(int i=0;i<=30;i++)
                tr[rt].cnt[i]=tr[ls].cnt[i]+tr[rs].cnt[i];
        }
    }
    void pushtag(int rt,int x)
    {
        if(tr[rt].p>=x)
            return ;
        if(tr[rt].ccnt&1)
            tr[rt].sum^=tr[rt].p,tr[rt].sum^=x;
        for(int i=0;i<=30;i++)
        {
            if(tr[rt].p>>i&1)
                tr[rt].cnt[i]-=tr[rt].ccnt;
            if(x>>i&1)
                tr[rt].cnt[i]+=tr[rt].ccnt;
        }
        tr[rt].p=x;
        tr[rt].tag=x;
    }
    void pushdown(int rt)
    {
        if(tr[rt].tag==-1)
            return ;
        pushtag(ls,tr[rt].tag);
        pushtag(rs,tr[rt].tag);
        tr[rt].tag=-1;
    }
    void build(int rt,int l,int r)
    {
        tr[rt].tag=-1;
        if(l==r)
        {
            tr[rt].p=a[l];
            tr[rt].sum=a[l];
            tr[rt].ccnt=1;
            tr[rt].op=1e18;
            for(int i=0;i<=30;i++)
                if(a[l]>>i&1)
                    tr[rt].cnt[i]++;
            return ;
        }
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(rt);
    }
    void modify(int rt,int l,int r,int cl,int cr,int x)
    {
        if(tr[rt].p>=x)
            return ;
        if(l>=cl&&r<=cr&&tr[rt].op>x)
        {
            pushtag(rt,x);
            return ;
        }
        pushdown(rt);
        if(cl<=mid)
            modify(ls,l,mid,cl,cr,x);
        if(cr>mid)
            modify(rs,mid+1,r,cl,cr,x);
        pushup(rt);
    }
    int qsum(int rt,int l,int r,int cl,int cr)
    {
        if(l>=cl&&r<=cr)
            return tr[rt].sum;
        pushdown(rt);
        int sum=0;
        if(cl<=mid)
            sum^=qsum(ls,l,mid,cl,cr);
        if(cr>mid)
            sum^=qsum(rs,mid+1,r,cl,cr);
        return sum;
    }
    int qbit(int rt,int l,int r,int cl,int cr,int x)
    {
        if(l>=cl&&r<=cr)
            return tr[rt].cnt[x];
        pushdown(rt);
        int sum=0;
        if(cl<=mid)
            sum+=qbit(ls,l,mid,cl,cr,x);
        if(cr>mid)
            sum+=qbit(rs,mid+1,r,cl,cr,x);
        return sum;
    }
}tree;
int main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++)
        cin>>a[i]; 
    tree.build(1,1,n);
    while(m--)
    {
        int op,l,r,x;
        cin>>op>>l>>r>>x;
        if(op&1)
            tree.modify(1,1,n,l,r,x);
        else
        {
            int sum=tree.qsum(1,1,n,l,r),max=-1;
            sum^=x;
            for(int i=30;i>=0;i--)
            {
                if(sum>>i&1)
                {
                    max=i;
                    break;
                }
            }
            if(max==-1)
                cout<<0<<endl;
            else
                cout<<tree.qbit(1,1,n,l,r,max)+(x>>max&1)<<endl;
        }
    }
    return 0;
}