浅谈树上主席树

· · 算法·理论

主席树这个东西有很多比较好的性质,比如可以直接拿两个版本相减来得到一个区间的权值线段树,可以比较容易地维护一些排名问题之类的

那么我们把问题转换成树上问题,比如查询两点之间的第 k 大,怎么做呢?

非常容易想到,利用加减得到链的权值线段树。

首先我们可以想想怎么在树上建主席树来方便地加减,最简单的方法就是直接把某一点的版本设为其到根上的路径的权值线段树,这样就可以容易凑出一条路径。

我们求一下 \text{LCA},设点 x 版本的 权值线段树为 F_x,最近公共祖先为 t,然后路径 x\to y 的权值线段树就是 F_x+F_y-F_{t}-F_{fa_t }

P2633 Count on a tree

板子题,把链的权值线段树弄出来直接找第 k 大即可。

时间复杂度为 \mathcal O(n\log n)

#include<bits/stdc++.h>
#define LL long long
#define LF long double
#define pLL pair<LL,LL>
#define pb push_back
//#define fir first
//#define sec second
using namespace std;
const LL inf=3e9;
const LL N=2e5+5;
const LL M=6e7+5;
const LL K=20;
//const LL mod;
//const LF eps;
//const LL P;
struct node
{
    LL l,r,sz;
}t[M];
vector<LL>v[N];
LL n,tot,rt[N],f[N][K+5],dep[N],a[N];
void ins(LL &rt,LL l,LL r,LL x)
{
    t[++tot]=t[rt];
    rt=tot;
    if(l==r)
    {
        //  cout<<rt<<' '<<l<<' '<<x<<endl;
        t[rt].sz++;
        return;
    }
    LL mid=(l+r)/2;
    if(x<=mid)ins(t[rt].l,l,mid,x);
    else ins(t[rt].r,mid+1,r,x);
    t[rt].sz=t[t[rt].l].sz+t[t[rt].r].sz;
    //  cout<<l<<' '<<r<<' '<<t[rt].sz<<endl;

}
LL query(LL rt,LL rt2,LL rt3,LL rt4,LL l,LL r,LL x)
{
    LL mid=(l+r)/2;
    LL sz=t[t[rt].l].sz+t[t[rt3].l].sz-t[t[rt2].l].sz-t[t[rt4].l].sz;
    if(l==r)return l;
    if(x<=sz)return query(t[rt].l,t[rt2].l,t[rt3].l,t[rt4].l,l,mid,x); 
    return query(t[rt].r,t[rt2].r,t[rt3].r,t[rt4].r,mid+1,r,x-sz);
}
void dfs(LL x,LL fa)
{
    dep[x]=dep[fa]+1,f[x][0]=fa;
    for(int i=1;i<=K;i++)f[x][i]=f[f[x][i-1]][i-1];
    rt[x]=rt[fa];
    ins(rt[x],1,inf,a[x]);
    for(LL i:v[x])
    {
        if(i==fa)continue;
        dfs(i,x);
    }
}
inline LL lca(LL x,LL y)
{
    if(dep[x]<dep[y])swap(x,y);
    for(int i=K;i>=0;i--)if(dep[f[x][i]]>=dep[y])x=f[x][i];
    if(x==y)return x;
    for(int i=K;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    return f[x][0];
}
LL q,lst;
int main()
{
    scanf("%lld%lld",&n,&q);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&a[i]);
    }
    for(int i=1;i<=n-1;i++)
    {
        LL x,y;
        scanf("%lld%lld",&x,&y);
        v[x].pb(y),v[y].pb(x);
    }
    dfs(1,0);
    while(q--)
    {
        LL x,y,k;
        scanf("%lld%lld%lld",&x,&y,&k);
        x^=lst;
        LL t=lca(x,y),ft=f[t][0];
        cerr<<t<<endl;
        LL ans=query(rt[x],rt[t],rt[y],rt[ft],1,inf,k);
        lst=ans;
        printf("%lld\n",ans);
    }
    return 0;
}
//RP++

P3302 [SDOI2013] 森林

这题也差不多,注意到加边操作,其实不难搞,这题谁当根都差不多,每次加边暴力重新处理一整棵树的信息即可。

但这样肯定是不行的,显然需要启发式合并,暴力处理较小的一棵树即可。

启发式合并的时间复杂度也要算上,所以是 \mathcal O(n\log n\log V)

#include<bits/stdc++.h>
#define LL int
#define LF long double
#define pLL pair<LL,LL>
#define pb push_back
//#define fir first
//#define sec second
using namespace std;
const LL inf=1e9;
const LL N=1e5+5;
const LL M=3e7+5;
const LL K=20;
//const LL mod;
//const LF eps;
//const LL P;
struct node
{
    LL l,r,sz;
}t[M];
vector<LL>v[N];
LL n,tot,Fa[N],sz[N],rt[N],f[N][K+5],sum[N],dep[N],a[N];
LL find(LL x)
{
    if(Fa[x]==x)return x;
    return Fa[x]=find(Fa[x]);
}
void ins(LL &rt,LL l,LL r,LL x)
{
    t[++tot]=t[rt];
    rt=tot;
    if(l==r)
    {
        //  cout<<rt<<' '<<l<<' '<<x<<endl;
        t[rt].sz++;
        return;
    }
    LL mid=l+r>>1;
    if(x<=mid)ins(t[rt].l,l,mid,x);
    else ins(t[rt].r,mid+1,r,x);
    t[rt].sz=t[t[rt].l].sz+t[t[rt].r].sz;
    //  cout<<l<<' '<<r<<' '<<t[rt].sz<<endl;

}
LL query(LL rt,LL rt2,LL rt3,LL rt4,LL l,LL r,LL x)
{
    LL mid=(l+r)/2;
    LL sz=t[t[rt].l].sz+t[t[rt3].l].sz-t[t[rt2].l].sz-t[t[rt4].l].sz;
//  cout<<l<<' '<<r<<' '<<x<<' '<<sz<<endl;
    if(l==r)return l;
    if(x<=sz)return query(t[rt].l,t[rt2].l,t[rt3].l,t[rt4].l,l,mid,x); 
    return query(t[rt].r,t[rt2].r,t[rt3].r,t[rt4].r,mid+1,r,x-sz);
}
void dfs(LL x,LL fa)
{
    dep[x]=dep[fa]+1,f[x][0]=fa;
    for(int i=1;i<=K;i++)f[x][i]=f[f[x][i-1]][i-1];
    rt[x]=rt[fa];
    sum[x]=sum[fa]+a[x];
    ins(rt[x],0,inf,a[x]);
    for(LL i:v[x])
    {
        if(i==fa)continue;
        dfs(i,x);
    }
}
inline LL lca(LL x,LL y)
{
    if(dep[x]<dep[y])swap(x,y);
    for(int i=K;i>=0;i--)if(dep[f[x][i]]>=dep[y])x=f[x][i];
    if(x==y)return x;
    for(int i=K;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    return f[x][0];
}
LL id,q,m,lst;
int main()
{
    scanf("%d",&id);
    scanf("%d%d%d",&n,&m,&q);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        Fa[i]=i,sz[i]=1,sum[i]=a[i],dep[i]=1;
        ins(rt[i],0,inf,a[i]);
    }
    for(int i=1;i<=m;i++)
    {
        LL x,y;
        scanf("%d%d",&x,&y);
        LL fx=find(x),fy=find(y);
        if(sz[fx]<sz[fy])swap(x,y),swap(fx,fy);
        Fa[fy]=fx,sz[fx]+=sz[fy];
        dfs(y,x);
        v[x].pb(y),v[y].pb(x);
    }
    while(q--)
    {
        char op[5];
        LL x,y;
        scanf("%s%d%d",op,&x,&y);
        x^=lst,y^=lst;
        if(op[0]=='L')
        {
            LL fx=find(x),fy=find(y);
            if(sz[fx]<sz[fy])swap(x,y),swap(fx,fy);
            Fa[fy]=fx,sz[fx]+=sz[fy];
            dfs(y,x);
            v[x].pb(y),v[y].pb(x);
        }
        else
        {
            int k;
            scanf("%d",&k);
            k^=lst;
            LL t=lca(x,y),ft=f[t][0];
            LL ans=query(rt[x],rt[ft],rt[y],rt[t],0,inf,k);
            lst=ans;
            printf("%d\n",ans);
        }
    }       

    return 0;
}
//RP++