P3865 【模板】ST 表 && RMQ 问题题解

· · 题解

UPD:更正时间复杂度。

UPD:修改了对于预处理中公式的证明。

UPD:更换了更美观的图片。

前置知识

如下图所示为左边界为 1 的情况:

而假设我们当前要求区间 [a,b] 的最大值,那么就可以求出一个 k,使得 2^k\leq b-a+12^{k+1}\geq b-a+1 。然后返回 \max(st_{a,k},st_{b-2^k+1,k}) 即可。

如下图为查询区间 [2,7] 的最大值的情况:

正确性证明

首先证明 k 一定存在且唯一。设区间长度为 len,那么存在 2^p=len,令 k=\lfloor p \rfloor,显然此时 2^k\leq len2^{k+1}\geq len。由此也可以得出计算 k 的方式,即 k=\lfloor\log_2 len\rfloor

然后显然这样算一定会覆盖区间 [a,b] 中的所有元素,且重叠部分不影响答案。

实现

预处理

可以从后往前,用类似 dp 的方式预处理。转移方程显然为:

st_{i,j}=\max(st_{i,j-1},st_{i+2^{j-1},j-1})

两重循环即可。

同时由于 c++ 中 log2(x) 函数的复杂度为 O(\log x)(一说接近 O(1)),要想实现真正的 O(1) 查询,还要预处理出 1n 的对数整数值,记为 lg_i。该值可以由下面这个公式估算:

lg_i=lg_{\lfloor\frac i 2\rfloor}+1

特别地,lg_0=-1,这是为了方便递推,实际查询时不会用到它。

证明

不会证明被审核员打回并嘲讽了。所以来简单证明一下。

\lfloor\log_2 x\rfloor 就相当于求其二进制最高位是第几位。那么显然不断右移直到仅剩一位,右移的次数即为答案。每次除以 2 并向下取整,就相当于右移一位。由于前面一定已经求出过此时的结果,直接调用并加 1 即可。

代码

void init() {
    lg[0] = -1;
    for (int i = 1; i < N; i++)
        lg[i] = lg[i / 2] + 1;
    for (int i = n; i >= 1; i--) {
        for (int j = 0; i + (1 << j) -1 <= n; j++) {
            if (j == 0)
                st[i][j] = a[i];
            else
                st[i][j] = max(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
        }
    }
}

查询

首先计算得出 k,然后返回 \max(st_{a,k},st_{b-2^k+1,k}) 即可。

代码

int query(int a, int b) {
    int len = (b - a + 1);
    int k = lg[len];
    return max(st[a][k], st[b - (1ll << k) +1][k]);
}

复杂度分析

预处理时间复杂度 O(n\log n),查询复杂度 O(1)。总复杂度 O(n\log n)

空间复杂度显然 O(n\log n)

完整代码

本题时限较紧,需要用快读卡常。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dd double
inline ll read() {
    ll x = 0, f = 1;
    char ch;
    while (((ch = getchar()) < 48 || ch > 57) && ch != EOF)if (ch == '-')f = -1;
    if (ch == EOF)x = EOF;
    while (ch >= 48 && ch <= 57)x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}
const ll N = 1e5+9, logN = 30;
ll st[N][logN];
ll n, m;
ll lg[N];
ll a[N];
void init() {
    lg[0] = -1;
    for (int i = 1; i < N; i++)
        lg[i] = lg[i / 2] + 1;
    for (int i = n; i >= 1; i--) {
        for (int j = 0; i + (1 << j) -1 <= n; j++) {
            if (j == 0)
                st[i][j] = a[i];
            else
                st[i][j] = max(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
        }
    }
}
int query(int a, int b) {
    int len = (b - a + 1);
    int k = lg[len];
    return max(st[a][k], st[b - (1ll << k) +1][k]);
}
int main() {
    n = read();
    m = read();
    for (int i = 1; i <= n; i++)
        a[i] = read();
    init();
    for (int i = 1; i <= m; i++) {
        ll l = read(), r = read();
        cout << query(l, r);
        putchar('\n');
    }
    return 0;
}