【学习笔记】主席树

· · 题解

前置知识

  1. 线段树(不会的先过【线段树1】 & 【线段树2】)
  2. 知道可持久化数据结构

主席树

是一种用来查询区间静态第 k 小的数据结构

原型为线段树,简单的来讲就是开 n 棵线段树,然后区间查询时只要查询第 l-1 棵和第 r 棵做前缀和做差就行了

但是开 n 棵线段树空间复杂度太大

于是我们发现,没加入一个数,只会在一条路径上更改,如果开 n 棵线段树会有很多冗余节点(即重复节点),所以每次加入一个数,我们就多开 \log n 个节点(即路径长度)就行了,空间复杂度为 O(n\log^2 n),时间复杂度为 O(n\log n)

1. 建空树(build

首先要建一棵空树,由于我们是用前缀和的思想查找,所以在 tree[0] 的位置是一棵空树

代码如下:

void build(int &t,int l,int r)  \\t是当前节点编号
{
    int mid=(l+r)>>1;
    t=++cnt;
    if (l==r) return;
    build(ls[t],l,mid);
    build(rs[t],mid+1,r);
}

2. 插入数字(modify)

在插入数字之前,我们先要对其进行离散,然后插入其编号

123 54 78 92 193

对其进行离散后为

4 1 2 3 5

空树:

插入 4:

插入 1:

插入 2:

插入 3:

插入 5:

每次修改时遍历修改的路径,判断其左儿子还是右儿子有变化,将没变化的儿子序号设为上一棵树对应节点的对应儿子,变化儿子的序号新建一个节点,继续遍历。

代码:

int modify(int X,int l,int r)
{
    int mid=(l+r)>>1,XX=++cnt;
    ls[XX]=ls[X],rs[XX]=rs[X],sum[XX]=sum[X]+1;
    if (l==r) return XX;
    if (x<=mid) ls[XX]=modify(ls[XX],l,mid);
    else rs[XX]=modify(rs[XX],mid+1,r);
    return XX;
}

3. 查询区间第 k 大(query

对于查询区间 lr 的第 k

我们同时遍历第 l-1 棵线段树和第 r 棵线段树

定义 xx=sum[num2]-sum[num1]num2 为第 r 棵线段树当前的节点,num1 表示第 l-1 棵线段树当前对应节点,sum 表示其数字个数之差)

k\le xx 时,直接向左走

k> xx 时,向右走,同时 k-=xx

代码:

int query(int num1,int num2,int l,int r,int k)
{
    int mid=(l+r)>>1,xx=sum[ls[num2]]-sum[ls[num1]];
    if (l==r) return l;
    if (k<=xx) return query(ls[num1],ls[num2],l,mid,k);
    else return query(rs[num1],rs[num2],mid+1,r,k-xx);
}

完整代码:

#include <bits/stdc++.h>
#define N 5000010

using namespace std;

int n,m,len,cnt,sum[N],x,l,r,k,ans;
int a[N],b[N],ls[N],rs[N],tree[N];

inline int read()
{
    int x=0,tag=1;
    char c=getchar();
    for (;c<'0' || c>'9';c=getchar()) if (c=='-') tag=-1;
    for (;c>='0' && c<='9';c=getchar()) x=(x<<1)+(x<<3)+c-'0';
    return x*tag;
}

void build(int &t,int l,int r)
{
    t=++cnt;
    if (l==r) return;
    int mid=(l+r)>>1;
    build(ls[t],l,mid);
    build(rs[t],mid+1,r);
}

int modify(int x,int l,int r,int k)
{
    int mid=(l+r)>>1,xx=++cnt;
    ls[xx]=ls[x],rs[xx]=rs[x],sum[xx]=sum[x]+1;
    if (l==r) return xx;
    if (k<=mid) ls[xx]=modify(ls[xx],l,mid,k);
    else rs[xx]=modify(rs[xx],mid+1,r,k);
    return xx;
}

int query(int n1,int n2,int l,int r,int k)
{
    int mid=(l+r)>>1,xx=sum[ls[n2]]-sum[ls[n1]];
    if (l==r) return l;
    if (k<=xx) return query(ls[n1],ls[n2],l,mid,k);
    else return query(rs[n1],rs[n2],mid+1,r,k-xx);
}

int main()
{
    n=read(),m=read();
    for (int i=1;i<=n;i++)
        a[i]=b[i]=read();
    sort(b+1,b+n+1);
    len=unique(b+1,b+n+1)-b-1;
    build(tree[0],1,len);
    for (int i=1;i<=n;i++)
    {
        x=lower_bound(b+1,b+len+1,a[i])-b;
        tree[i]=modify(tree[i-1],1,len,x);
    }
    for (int i=1;i<=m;i++)
    {
        l=read(),r=read(),k=read();
        printf("%d\n",b[query(tree[l-1],tree[r],1,len,k)]);
    }
    return 0;
}

另一道模板题

可持久化数组

不用离散,改成每个节点存数字就行了

查询时只要像线段树查询一样,不需要用前缀和

代码如下:

#include <bits/stdc++.h>
#define N 50000010
using namespace std;
int n,m,cnt,opt,loc,k,val;
int tree[N],num[N],ls[N],rs[N];
void build(int &t,int l,int r)
{
    int mid=(l+r)>>1;
    t=++cnt;
    if (l==r)
    {
        scanf("%d",&num[t]);
        return;
    }
    build(ls[t],l,mid);
    build(rs[t],mid+1,r);
}

int modify(int X,int l,int r)
{
    int mid=(l+r)>>1,XX=++cnt;
    ls[XX]=ls[X],rs[XX]=rs[X];
    if (l==r)
    {
        num[XX]=val;return XX;
    }
    if (k<=mid) ls[XX]=modify(ls[XX],l,mid);
    else rs[XX]=modify(rs[XX],mid+1,r);
    return XX;
}

int query(int i,int l,int r,int k)
{
    int mid=(l+r)>>1;
    if (l==r) return num[i];
    if (k<=mid) return query(ls[i],l,mid,k);
    else return query(rs[i],mid+1,r,k);
}

int main()
{
    scanf("%d%d",&n,&m);
    build(tree[0],1,n);
    for (int i=1;i<=m;i++)
    {
        scanf("%d%d",&loc,&opt);
        if (opt==1)
        {
            scanf("%d%d",&k,&val);
            tree[i]=modify(tree[loc],1,n);
        }
        else
        {
            scanf("%d",&k);
            printf("%d\n",query(tree[loc],1,n,k));
            tree[i]=tree[loc];
        }
    }
    return 0;
}