RMQ&LCA学习笔记

JS_TZ_ZHR

2021-02-22 22:31:10

Solution

要讲的内容:笛卡尔树,树的欧拉序,ST 表,分块 RMQ,$\pm1$RMQ。 ## 笛卡尔树 笛卡尔树是二叉树的一种,每个节点有两个关键字,分别满足堆和二叉搜索树的性质。一个序列上的笛卡尔树的两个关键字就是值和下标(其实可以用 Treap 来理解),如下(堆是小根堆): ![](https://oi-wiki.org/ds/images/cartesian-tree1.png) 图片引用自 OI-wiki。 通过笛卡尔树的定义,我们可以得到一个建树方式,对于 $[l,r]$ 的笛卡尔树,先找到最值的位置 $p$,$p$ 即为这棵树的根。在将 $[l,p-1]$ 的笛卡尔树作为左子树,$[p+1,r]$ 的笛卡尔树作为右子树。这个建树方式表明一个节点的左/右子树的下标只能在这个节点的下标的左 / 右。 笛卡尔树有一个美妙的性质,就是对于两个下标为 $l$ 和 $r$ 的节点,$[l,r]$ 的最值就是 $l$ 和 $r$ 在笛卡尔树上的最近公共祖先的值。 我们来试着证明这个性质(下面 $l$ 节点指的是下标为 $l$ 的节点)。 定理 $1$:$[l,r]$ 的最值就是 $l$ 和 $r$ 在笛卡尔树上的最近公共祖先的值。 证:若 $l$ 节点和 $r$ 节点都在他们 LCA 的左 / 右子树,那么这与LCA的定义矛盾。显然,$l$ 节点和 $r$ 节点的 LCA 是其子树中的最值并且涵盖了 $[l,r]$ 这个区间,于是这个性质就得证了。这在 RMQ 转 LCA 问题时有很大的用处。 关于笛卡尔树的构建,可以参考 [P5854](https://www.luogu.com.cn/problem/P5854),具体就是维护右链,对于新插入的值,找到右链上第值一个小于它的节点,变成它的右儿子,它原来的右儿子变为它的左儿子。没有这个节点,直接变为根,原来的树变成它的左子树。 核心代码: ```cpp for(int i=1;i<=n;i++){ k=top; while(k>0&&a[s[k]]>a[i])k--; if(k)rs[s[k]]=i; if(k<top)ls[i]=s[k+1]; s[++k]=i; top=k; } ``` 笛卡尔树其实可以从根暴力往下跳直到找到在 $[l,r]$ 之间的节点。这样做是 $O(nm)$ 的,但随机数据下比较快。如 [P3793](https://www.luogu.com.cn/problem/P3793)。 代码: ```cpp #include<bits/stdc++.h> #define N 20000005 using namespace std; int n,a[N],s[N],top,ls[N],rs[N],k,x,root,m,_s,l,r; long long ans; int query(int rt,int l,int r) { while(rt<l||rt>r) { if(rt<l)rt=rs[rt]; else rt=ls[rt]; } return rt; } namespace GenHelper { unsigned z1,z2,z3,z4,b; unsigned rand_() { b=((z1<<6)^z1)>>13; z1=((z1&4294967294U)<<18)^b; b=((z2<<2)^z2)>>27; z2=((z2&4294967288U)<<2)^b; b=((z3<<13)^z3)>>21; z3=((z3&4294967280U)<<7)^b; b=((z4<<3)^z4)>>12; z4=((z4&4294967168U)<<13)^b; return (z1^z2^z3^z4); } } void srand(unsigned x) { using namespace GenHelper; z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51; } int read() { using namespace GenHelper; int a=rand_()&32767; int b=rand_()&32767; return a*32768+b; } int main() { scanf("%d%d%d",&n,&m,&_s); srand(_s); for(int i=1; i<=n; i++)a[i]=-read(); for(int i=1; i<=n; i++) { k=top; while(k>0&&a[s[k]]>a[i])k--; if(k)rs[s[k]]=i; if(k<top)ls[i]=s[k+1]; s[++k]=i; top=k; } root=s[1]; while(m--) { l=read()%n+1,r=read()%n+1; if(l>r)swap(l,r); ans-=a[query(root,l,r)]; } cout<<ans<<endl; } ``` ## 树的欧拉序 树的欧拉序就是在 DFS 时,无论是访问到还是回溯时,都记录一遍点。这么说可能不理解,放代码更直观一点。 ```cpp void dfs(int now,int fa) { e[++cnt]=now;//e[i]表示欧拉序里第i个节点是什么 fi[now]=cnt;//fi[now]表示now节点第一次出现的位置 for(int i=ve[now].size()-1; i>=0; i--) { int y=ve[now][i]; if(y==fa)continue; dep[y]=dep[now]+1; dfs(y,now); e[++cnt]=now;//回溯时记录 } } ``` 从这里可以看出,树的欧拉序长点数加边数 $=2n-1$。 理解了树的欧拉序之后,我们来看它有什么性质可以利用。 定理 $2$:$i$ 和 $j$ 两个节点的 LCA 就是 $i$ 第一次出现的位置和 $j$ 第一次出现的位置之间深度最小的那个点。 证:当 $i=j$ 时显然。 若 $i$ 和 $j$ 中有一个是它们的 LCA ,那么在进入不是 LCA 的那个节点时到第一次进入 LCA,只会遍历到 LCA 子树中的节点里。所以深度最小的点就是 LCA。 如果 $i$ 和 $j$ 中没有一个是它们的 LCA,那么它们肯定在 LCA的不同儿子的子树内。遍历完一个会回溯到 LCA,并且因为下面还有节点要遍历,不会回到祖先。于是定理得证。 这样,LCA 和 RMQ 这两个问题就变得等价了。 欧拉序还有一个性质,就是序列中相邻节点的深度差为 $\pm1$,这个很显然,会在以后讲到。 ## ST 表 ST表的核心思想是对于每个位置 $p$,预处理 $[p,p+2^0-1]$,$[p,p+2^1-1]$,$[p,p+2^2-1]$,$[p,p+2^3-1]$... 的最值。然后对于每个 $[l,r]$ 的询问,输出 $\max/\min(ST[l][k],ST[r-2^k+1][k])$ ($ST[p][k]$ 表示 $\max\limits_{l \le i \le p+2^k-1}/\min \limits_{l \le i \le p+2^k-1} a_i$ )。这里 $k=\log_2 (r-l+1)$。 代码: ```cpp #include<bits/stdc++.h> using namespace std; int n,m,a,b; int Max[100005][21]; inline int query(int l,int r) { register int k=log2(r-l+1); return max(Max[l][k],Max[r-(1<<k)+1][k]); } inline int read() { register int x=0; register char ch=getchar(); while(ch<'0'||ch>'9')ch=getchar(); while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch-'0'),ch=getchar(); return x; } int main() { scanf("%d%d",&n,&m); for(register int i=1; i<=n; i++)Max[i][0]=read(); for(register int j=1; j<=21; j++) for(register int i=1; i+(1<<j)-1<=n; i++) Max[i][j]=max(Max[i][j-1],Max[i+(1<<(j-1))][j-1]); for(register int i=1; i<=m; i++) { a=read(),b=read(); printf("%d\n",query(a,b)); } } ``` ## 分块 RMQ 分块 RMQ 是计算 RMQ 的一种思想,其一般与 ST 表等数据结构配合使用。这里先讲 $\log_2n$ 和 $\sqrt{n}$ 两种块长。最后的 +1-1RMQ 用的则是 $\frac{\log_2n}{2}$ 的块长。 先看 $\log_2n$( Four Russian 算法,有人叫它四毛子)。我们对于整块,预处理整块的最值。再直接用 ST 表。时间复杂度 $O(\frac{n}{\log n} \log(\frac{n}{\log n}))=O(n)$。对于块内,也用 ST 表预处理,时间复杂度 $O(\frac{n}{\log n} \log \log n)=O(n \log \log n)$,查询时对整块和零散快分别 $O(1)$ 查询。 [P3793](https://www.luogu.com.cn/problem/P3793) 代码(被卡了,只能过 $n=1.5*10^7$): ```cpp #include<bits/stdc++.h> #define N 20001005 using namespace std; int n,m,s,a[1000005],l,r,B,ST1[900000][26],ST2[900000][26][5],cnt; long long ans; namespace GenHelper { unsigned z1,z2,z3,z4,b; unsigned rand_() { b=((z1<<6)^z1)>>13; z1=((z1&4294967294U)<<18)^b; b=((z2<<2)^z2)>>27; z2=((z2&4294967288U)<<2)^b; b=((z3<<13)^z3)>>21; z3=((z3&4294967280U)<<7)^b; b=((z4<<3)^z4)>>12; z4=((z4&4294967168U)<<13)^b; return (z1^z2^z3^z4); } } void srand(unsigned x) { using namespace GenHelper; z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51; } int read() { using namespace GenHelper; int a=rand_()&32767; int b=rand_()&32767; return a*32768+b; } int query1(int l,int r) { if(l>r)return 0; int k=a[r-l+1]; return max(ST1[l][k],ST1[r-(1<<k)+1][k]); } int query2(int p,int l,int r) { int k=a[r-l+1]; return max(ST2[p][l][k],ST2[p][r-(1<<k)+1][k]); } int query(int l,int r) { int L=(l-1)/B+1,R=(r-1)/B+1,res=0; if(L==R)return query2(L,((l-1)%B)+1,((r-1)%B)+1); res=query1(L+1,R-1); res=max(res,query2(L,((l-1)%B)+1,B)); res=max(res,query2(R,1,((r-1)%B)+1)); return res; } int main() { cin>>n>>m>>s; B=log2(n); cnt=(n-1)/B+1; srand(s); for(int i=1; i<=n; i++) { a[0]=read(); int now=(i-1)/B+1; if(a[0]>ST1[now][0])ST1[now][0]=a[0]; ST2[now][i-now*B+B][0]=a[0]; } for(int j=1; (1<<j)<=cnt; j++) for(int i=1; i<=cnt; i++) ST1[i][j]=max(ST1[i][j-1],ST1[i+(1<<(j-1))][j-1]); for(int i=1; i<=cnt; i++) for(int j=1; (1<<j)<=B; j++) for(int k=1; k<=B; k++) ST2[i][k][j]=max(ST2[i][k][j-1],ST2[i][k+(1<<(j-1))][j-1]); a[0]=a[1]=0; for(int i=2; i<=1000000; i++)a[i]=a[i>>1]+1; while(m--) { l=read()%n+1,r=read()%n+1; if(l>r)swap(l,r); ans+=query(l,r); } cout<<ans<<endl; } ``` 其实这种做法可以反复分块,但这样没有什么用,常数还大。 $\sqrt{n}$ 块长的就比较玄学了,对于整块,直接暴力 $O(\sqrt{n}^2)$ 预处理,块内只预处理前后缀最值,这样,$l$ 和 $r$ 不在同一块的情形就在 $O(1)$ 的时间解决了。容易发现,$l$ 和 $r$ 在同一块的概率大概是 $\frac{1}{\sqrt{n}}$ 的,$m$ 次询问这种情况运行 $O(\frac{m}{\sqrt{n}}\sqrt{n})=O(m)$,于是这种做法在随机数据下就是线性的,但上界其实是 $O(m \sqrt{n})$ 的。这样也可以通过 [P3793](https://www.luogu.com.cn/problem/P3793)。 代码: ```cpp #include<bits/stdc++.h> #define max Max using namespace std; int n,m,s,a[20000005],max1[2][5005][5005],max2[5005][5005]; int block,begin[1000005],end[1000005],l,r,num; long long ans; namespace GenHelper { unsigned z1,z2,z3,z4,b; unsigned rand_() { b=((z1<<6)^z1)>>13; z1=((z1&4294967294U)<<18)^b; b=((z2<<2)^z2)>>27; z2=((z2&4294967288U)<<2)^b; b=((z3<<13)^z3)>>21; z3=((z3&4294967280U)<<7)^b; b=((z4<<3)^z4)>>12; z4=((z4&4294967168U)<<13)^b; return (z1^z2^z3^z4); } } inline void srand(unsigned x) { using namespace GenHelper; z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51; } inline int read() { using namespace GenHelper; int a=rand_()&32767; int b=rand_()&32767; return a*32768+b; } inline int max(int a,int b) { return a>b?a:b; } void build() { block=sqrt(n); for(int i=1; i<=n; i+=block) { begin[++num]=i,end[num]=min(n,i+block-1); for(int j=i; j<=end[num]; j++) max1[0][num][j-i+1]=max(max1[0][num][j-i],a[j]); for(int j=end[num]; j>=i; j--) max1[1][num][end[num]-j+1]=max(max1[1][num][end[num]-j],a[j]); } for(int i=1; i<=num; i++) { for(int j=i; j<=num; j++) { max2[i][j]=max2[i][j-1]; max2[i][j]=max(max2[i][j],max1[0][j][end[j]-begin[j]+1]); } } } int query(int l,int r){ int numl=(l/block)+(l%block!=0); int numr=(r/block)+(r%block!=0); int res=0; if(numl==numr){ for(int i=l;i<=r;i++)res=max(res,a[i]); return res; } res=max(max1[1][numl][end[numl]-l+1],max1[0][numr][r-begin[numr]+1]); return max(max2[numl+1][numr-1],res); } int main() { scanf("%d%d%d",&n,&m,&s); srand(s); for(int i=1; i<=n; i++)a[i]=read(); build(); while(m--){ l=read()%n+1,r=read()%n+1; if(l>r)swap(l,r); ans+=(unsigned long long)query(l,r); } cout<<ans; } ``` ## $\pm1$RMQ 重点来了。 上文提到,欧拉序中相邻的两个节点的深度值差为 $\pm1$。考虑利用这个性质。这种情况下可以使用分块 RMQ ,对于一个块,只有 $2^{B-1}$ 种本质不同的可能( $B$ 是块长,有$B-1$个相邻的数对,且有$\pm1$两种可能)。于是块内可以在 $O(2^{B-1}B^2)$ 处理每种块的所有区间的最值的位置。容易发现,$B$ 取 $\frac{\log_2n}{2}$ 时可以在线性时间内解决。块内是 $O(\sqrt{n} \log^2n)$ 的复杂度。而整块,直接用 Four Russian 算法,还是线性的,于是这个问题就解决了。 最终代码: RMQ: ```cpp #include<bits/stdc++.h> #define N 2000005 using namespace std; int n,m,a[N],root,s[N],top,e[N],cnt,fi[N],sum[N],dep[N],B=8; int M[2000][12][12],pre[2000][12],suf[2000][12],Min[50005],ST[50005][25]; int typ[N],Sum[N],lg[N]; struct T { int val,l,r; } t[N]; vector<int>ve[N]; void build1() { for(int i=1; i<=n; i++) { t[i].val=a[i],s[top+1]=0; while(top&&t[s[top]].val<a[i])top--; t[i].l=s[top+1]; if(top)t[s[top]].r=i; s[++top]=i; } root=s[1]; } void build2() { int num[1005]; while(cnt&7)cnt++; for(int i=0; i<(1<<(B-1)); i++) { num[1]=0; for(int j=0; j<B-1; j++) { if(i&(1<<j))num[j+2]=num[j+1]+1; else num[j+2]=num[j+1]-1; } for(int j=1; j<=B; j++) { M[i][j][j]=j; for(int k=j+1; k<=B; k++) M[i][j][k]=num[k]<num[M[i][j][k-1]]?k:M[i][j][k-1]; } for(int j=1; j<=B; j++)pre[i][j]=M[i][1][j],suf[i][j]=M[i][j][B]; Min[i]=M[i][1][B]; } for(int i=1; i<=(cnt>>3); i++) { typ[i]=0; for(int j=(i<<3)-B+2,t=0; j<=(i<<3); j++,t++) if(sum[j]-sum[j-1]>=0)typ[i]|=(1<<t); Sum[i]=sum[Min[typ[i]]+((i-1)<<3)]; } for(int i=1; i<=(cnt>>3); i++)ST[i][0]=i; int len=lg[cnt]-2; for(int j=1; j<=len; j++) { for(int i=1; i<=(cnt>>3)-(1<<j)+1; i++) { int k=i+(1<<(j-1)); ST[i][j]=Sum[ST[k][j-1]]<Sum[ST[i][j-1]]?ST[k][j-1]:ST[i][j-1]; } } } int Brmq(int l,int r) { int k=lg[r-l+1]; r-=(1<<k)-1; return Sum[ST[l][k]]<Sum[ST[r][k]]?ST[l][k]:ST[r][k]; } void dfs(int now) { e[++cnt]=now; fi[now]=cnt; for(int i=ve[now].size()-1; i>=0; i--) { int y=ve[now][i]; dep[y]=dep[now]+1; dfs(y); e[++cnt]=now; } } int query(int l,int r) { if(l>r)swap(l,r); int L=(l-1)/B+1,R=(r-1)/B+1; if(L==R) { int pos=(L-1)<<3; return pos+M[typ[L]][l-pos][r-pos]; } else if(L+1==R) { int pl=(L-1)<<3,pr=(R-1)<<3; return sum[suf[typ[L]][l-pl]+pl]<sum[pre[typ[R]][r-pr]+pr]?suf[typ[L]][l-pl]+pl:pre[typ[R]][r-pr]+pr; } else { int pl=(L-1)<<3,pr=(R-1)<<3; int t1=Brmq(L+1,R-1),t2=sum[suf[typ[L]][l-pl]+pl]<sum[pre[typ[R]][r-pr]+pr]?suf[typ[L]][l-pl]+pl:pre[typ[R]][r-pr]+pr; t1=((t1-1)<<3)+Min[typ[t1]]; return sum[t1]<sum[t2]?t1:t2; } } int LCA(int x,int y) { return e[query(fi[x],fi[y])]; } namespace IO { const int MAXSIZE = 1 << 20; char buf[MAXSIZE], *p1, *p2; #define gc() \ (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) \ ? EOF \ : *p1++) inline int read() { int x = 0, f = 1; char c = gc(); while (!isdigit(c)) { if (c == '-') f = -1; c = gc(); } while (isdigit(c)) x = x * 10 + (c ^ 48), c = gc(); return x * f; } char pbuf[1 << 20], *pp = pbuf; inline void push(const char &c) { if (pp - pbuf == 1 << 20) fwrite(pbuf, 1, pp - pbuf, stdout), pp = pbuf; *pp++ = c; } inline void write(int x) { static int sta[35]; int top=0; do { sta[++top] = x % 10, x /= 10; } while (x); while (top) push(sta[top--]+'0'); push('\n'); } } using namespace IO; int main() { n=read(),m=read(); for(int i=2; i<=n*2+1; i++)lg[i]=lg[i>>1]+1; for(int i=1; i<=n; i++)a[i]=read(); build1(); for(int i=1; i<=n; i++) { if(t[i].l)ve[i].push_back(t[i].l); if(t[i].r)ve[i].push_back(t[i].r); } dfs(root); for(int i=1; i<=cnt; i++)sum[i]=dep[e[i]]; build2(); int l,r; while(m--) { l=read(),r=read(); write(a[LCA(l,r)]); } fwrite(pbuf, 1, pp - pbuf, stdout); } ``` LCA: ```cpp #include<bits/stdc++.h> #define N 2000005 using namespace std; int n,m,root,s[N],top,e[N],cnt,fi[N],sum[N],dep[N],B; int M[10005][12][12],pre[10005][12],suf[10005][12],Min[10005],ST[200005][25]; int typ[N],Sum[N],lg[N]; vector<int>ve[N]; void build() { int num[2005]; while(cnt%B)cnt++; for(int i=0; i<(1<<(B-1)); i++) { num[1]=0; for(int j=0; j<B-1; j++) { if(i&(1<<j))num[j+2]=num[j+1]+1; else num[j+2]=num[j+1]-1; } for(int j=1; j<=B; j++) { M[i][j][j]=j; for(int k=j+1; k<=B; k++) M[i][j][k]=num[k]<num[M[i][j][k-1]]?k:M[i][j][k-1]; } for(int j=1; j<=B; j++)pre[i][j]=M[i][1][j],suf[i][j]=M[i][j][B]; Min[i]=M[i][1][B]; } for(int i=1; i<=cnt/B; i++) { typ[i]=0; for(int j=(i*B)-B+2,t=0; j<=(i*B); j++,t++) if(sum[j]-sum[j-1]>=0)typ[i]|=(1<<t); Sum[i]=sum[Min[typ[i]]+((i-1)*B)]; } for(int i=1; i<=(cnt/B); i++)ST[i][0]=i; int len=lg[cnt]-2; for(int j=1; j<=len; j++) { for(int i=1; i<=(cnt/B)-(1<<j)+1; i++) { int k=i+(1<<(j-1)); ST[i][j]=Sum[ST[k][j-1]]<Sum[ST[i][j-1]]?ST[k][j-1]:ST[i][j-1]; } } } int Brmq(int l,int r) { int k=lg[r-l+1]; r-=(1<<k)-1; return Sum[ST[l][k]]<Sum[ST[r][k]]?ST[l][k]:ST[r][k]; } void dfs(int now,int fa) { e[++cnt]=now; fi[now]=cnt; for(int i=ve[now].size()-1; i>=0; i--) { int y=ve[now][i]; if(y==fa)continue; dep[y]=dep[now]+1; dfs(y,now); e[++cnt]=now; } } int query(int l,int r) { if(l>r)swap(l,r); int L=(l-1)/B+1,R=(r-1)/B+1; if(L==R) { int pos=(L-1)*B; return pos+M[typ[L]][l-pos][r-pos]; } else if(L+1==R) { int pl=(L-1)*B,pr=(R-1)*B; return sum[suf[typ[L]][l-pl]+pl]<sum[pre[typ[R]][r-pr]+pr]?suf[typ[L]][l-pl]+pl:pre[typ[R]][r-pr]+pr; } else { int pl=(L-1)*B,pr=(R-1)*B; int t1=Brmq(L+1,R-1),t2=sum[suf[typ[L]][l-pl]+pl]<sum[pre[typ[R]][r-pr]+pr]?suf[typ[L]][l-pl]+pl:pre[typ[R]][r-pr]+pr; t1=((t1-1)*B)+Min[typ[t1]]; return sum[t1]<sum[t2]?t1:t2; } } int LCA(int x,int y) { return e[query(fi[x],fi[y])]; } namespace IO { const int MAXSIZE = 1 << 20; char buf[MAXSIZE], *p1, *p2; #define gc() \ (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) \ ? EOF \ : *p1++) inline int read() { int x = 0, f = 1; char c = gc(); while (!isdigit(c)) { if (c == '-') f = -1; c = gc(); } while (isdigit(c)) x = x * 10 + (c ^ 48), c = gc(); return x * f; } char pbuf[1 << 20], *pp = pbuf; inline void push(const char &c) { if (pp - pbuf == 1 << 20) fwrite(pbuf, 1, pp - pbuf, stdout), pp = pbuf; *pp++ = c; } inline void write(int x) { static int sta[35]; int top=0; do { sta[++top] = x % 10, x /= 10; } while (x); while (top) push(sta[top--]+'0'); push('\n'); } } using namespace IO; int main() { n=read(),m=read(),root=read(); B=log2(2.0*n)/2.0; for(int i=2; i<=n*2+1; i++)lg[i]=lg[i>>1]+1; for(int i=1; i<n; i++) { int u,v; u=read(),v=read(); ve[u].push_back(v); ve[v].push_back(u); } dfs(root,0); for(int i=1; i<=cnt; i++)sum[i]=dep[e[i]]; build(); int l,r; while(m--) { l=read(),r=read(); write(LCA(l,r)); } fwrite(pbuf, 1, pp - pbuf, stdout); } ``` ### UPD:改正了一个代码错误。