P12478 题解

· · 题解

Problem Link

题目大意

给定 a_1\sim a_{2^n},满足 a_i\in[0,2^n),定义 f(l,r) 表示 \max_x\mathrm{mex}(a_l\oplus x,a_{l+1}\oplus x,\dots,a_r\oplus x)

支持 q 次询问一个区间的 f 值,或者所有子区间的 f 值之和。

数据范围:n\le 18,q\le 10^6

思路分析

首先刻画 f,把所有元素建 Trie,计算每个子树的 f_p,如果有一个子树是满的,则 f_p=f_{ls}+f_{rs},否则 f_p=\max (f_{ls},f_{rs})

但这个做法不能很好的刻画 f(l,r)

注意到 f 的本质就是选择 Trie 树上的一个点,然后把他到根的链删掉,剩余的满子树大小之和。

那么如果 a 是一个排列,我们就可以扫描线,在每个子树被填满的时候更新他的兄弟子树内每个叶子的权值。

注意到每个叶子的权值都是一个 n 段的分段函数,因此总的更新次数是 \mathcal O(n^22^n) 的,再用一个线段树维护即可。

进一步你把每个叶子的分段函数先归并起来,这样的更新次数是 \mathcal O(n2^n) 的。

接下来回到原问题,此时每个子树会被多次填满,我们就不能每次暴力更新兄弟子树的分段函数。

考虑一个比较方便的维护分段函数的方法,对每个子树建一棵刚才的 Trie 树,并且维护最大出现位置最小的一个,然后把这个元素删掉,重复此过程,每次取出 f_{rt} 即可。

然后当一个子树填满的时候,设此前子树内最大出现位置最小的元素时刻为 t',而现在被更新为 t,那么我们只关心其兄弟子树在 [t',t] 范围内的分段函数。

那么我们对每个子树建 Trie 树维护 dp,然后我们不断弹出兄弟子树最小时刻在 [t',t] 范围内的元素,就能得到这个范围的分段函数。

很显然得到的总段数不超过每个 Trie 树上插入操作的次数,那么操作次数就是 \mathcal O(n2^n) 级别的。

由于 f 具有单调性,那么区间 chkmax 相当于区间赋值,用线段树支持区间赋值区间历史和即可。

时间复杂度 \mathcal O(n^22^n+nq)

代码呈现

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=(1<<18)+5,inf=1e9;
int ty,n,m,q;
struct Trie {
    int N,lo,*f,*s;
    array <int,2> *mn;
    void psu(int p) {
        mn[p]=min(mn[p<<1],mn[p<<1|1]),f[p]=max(f[p<<1],f[p<<1|1]);
        if(f[p]==s[p]/2) f[p]=f[p<<1]+f[p<<1|1];
    }
    void init(int l,int r) {
        N=r-l+1,lo=l,f=new int[2*N],s=new int[2*N],mn=new array<int,2>[2*N];
        for(int i=N;i<2*N;++i) s[i]=1,f[i]=0,mn[i]={inf,i};
        for(int i=N-1;i;--i) psu(i),s[i]=s[i<<1]+s[i<<1|1];
    }
    void upd(int p,int t) {
        p=p-lo+N,f[p]=1,mn[p]={t,p};
        for(p>>=1;p;p>>=1) psu(p);
    }
    void pop() {
        int p=mn[1][1]; f[p]=0,mn[p]={inf,p};
        for(p>>=1;p;p>>=1) psu(p);
    }
}   tr[MAXN*2];
struct SegmentTree {
    int mn[MAXN*2],mx[MAXN*2],tg[MAXN*2],len[MAXN*2],ct[MAXN*2];
    ll su[MAXN*2],hs[MAXN*2],ht[MAXN*2];
    //ct -> tg -> ht
    void adt(int p,int k) { hs[p]+=1ll*k*su[p],(~tg[p]?ht[p]+=1ll*tg[p]*k:ct[p]+=k); }
    void cov(int p,int k) { mn[p]=mx[p]=tg[p]=k,su[p]=1ll*len[p]*k; }
    void adh(int p,ll h) { hs[p]+=h*len[p],ht[p]+=h; }
    void psd(int p) {
        if(ct[p]) adt(p<<1,ct[p]),adt(p<<1|1,ct[p]),ct[p]=0;
        if(~tg[p]) cov(p<<1,tg[p]),cov(p<<1|1,tg[p]),tg[p]=-1;
        if(ht[p]) adh(p<<1,ht[p]),adh(p<<1|1,ht[p]),ht[p]=0;
    }
    void psu(int p) {
        su[p]=su[p<<1]+su[p<<1|1],hs[p]=hs[p<<1]+hs[p<<1|1];
        mn[p]=min(mn[p<<1],mn[p<<1|1]),mx[p]=max(mx[p<<1],mx[p<<1|1]);
    }
    void init(int l=0,int r=n-1,int p=1) {
        tg[p]=-1,len[p]=r-l+1,tr[p].init(l,r);
        if(l==r) return ;
        int mid=(l+r)>>1;
        init(l,mid,p<<1),init(mid+1,r,p<<1|1);
    }
    void upd(int ul,int ur,int v,int l=0,int r=n-1,int p=1) {
        if(mn[p]>=v) return ;
        if(ul<=l&&r<=ur&&mx[p]<=v) return cov(p,v);
        int mid=(l+r)>>1; psd(p);
        if(ul<=mid) upd(ul,ur,v,l,mid,p<<1);
        if(mid<ur) upd(ul,ur,v,mid+1,r,p<<1|1);
        psu(p);
    }
    ll qhs(int ul,int ur,int l=0,int r=n-1,int p=1) {
        if(ul<=l&&r<=ur) return hs[p];
        int mid=(l+r)>>1; ll s=0; psd(p);
        if(ul<=mid) s+=qhs(ul,ur,l,mid,p<<1);
        if(mid<ur) s+=qhs(ul,ur,mid+1,r,p<<1|1);
        return s;
    }
    int qv(int u,int l=0,int r=n-1,int p=1) {
        if(l==r) return su[p];
        int mid=(l+r)>>1; psd(p);
        return u<=mid?qv(u,l,mid,p<<1):qv(u,mid+1,r,p<<1|1);
    }
}   T;
int a[MAXN],mn[MAXN*2];
ll ans[MAXN*4];
vector <array<int,2>> qy[MAXN];
struct info {
    int d,w,t,o;
};
void upd(int x,int t) {
    vector <info> op;
    op.push_back({0,1,t,0});
    for(int d=0;d<m;++d) {
        int p=(x+n)>>d,lst=mn[p];
        tr[p].upd(x,t),mn[p]=(d?min(mn[p<<1],mn[p<<1|1]):t);
        if(mn[p^1]>=0) op.push_back({d+1,1<<d,mn[p^1],1});
        if(mn[p]==lst) continue;
        for(;tr[p^1].mn[1][0]<=mn[p];tr[p^1].pop()) {
            op.push_back({d+1,tr[p^1].f[1]+(1<<d),tr[p^1].mn[1][0],0});
        }
        if(tr[p^1].mn[1][0]<=t) op.push_back({d+1,tr[p^1].f[1]+(1<<d),mn[p],0});
    }
    sort(op.begin(),op.end(),[&](auto i,auto j){ return i.t>j.t; });
    static int f[20],w[20];
    memset(f,0,sizeof(f)),memset(w,0,sizeof(w));
    for(auto e:op) {
        if(e.o) for(int i=0;i<e.d;++i) w[i]+=e.w;
        else f[e.d]=max(f[e.d],e.w);
        int z=0;
        for(int i=0;i<=m;++i) z=max(z,f[i]+w[i]);
        T.upd(0,e.t,z);
    }
}
signed main() {
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    cin>>m>>q>>ty,n=1<<m,T.init();
    for(int i=0;i<n;++i) cin>>a[i];
    for(int i=1,l,r;i<=q;++i) cin>>l>>r,qy[r-1].push_back({l-1,i});
    memset(mn,-0x3f,sizeof(mn));
    for(int i=0;i<n;++i) {
        upd(a[i],i),T.adt(1,1);
        for(auto o:qy[i]) ans[o[1]]=(ty==1?T.qv(o[0]):T.qhs(o[0],i));
    }
    for(int i=1;i<=q;++i) cout<<ans[i]<<"\n";
    return 0;
}