单调队列优化ST表

2021-09-19 21:38:24


这里是题目链接

萌新的第一篇题解,这道题做了好久,期间某一次提交超时了0.08s,这可太难受了。

题目分析

题目需要求任意大小为 $k$ 的三角形里的数的最大值,这其实可以看作是一种RMQ问题,又由于数据是静态的,所以很容易可以想到用ST表来求。然而萌新不太会处理倒三角形的问题,所以我是全用正三角形进行转移的:

st[k][i][j] = max(st[k - 1][i][j], *max_element(st[k - 1][i + ju] + j, st[k - 1][i + ju] + j + ju + 1));

$st_{k,i,j}$ 代表以第 $i$ 行第 $j$ 列为上顶点的大小为 $2^k$(这个 $k$ 不是题目里的 $k$,只是一个普通的变量)的三角形的最大值,而 max_element 函数是一个求地址 $[a,b)$ 之间的最大值的函数,函数会返回最大值的地址,该函数的时间复杂度近似为 $O(b-a)$。

比如说以第一行第一列为上顶点,大小为4的三角形的最大值,就等于这4个红色的大小为2的三角形的最大值的最大值(可能有点绕)。

所以,这是一个空间复杂度为 $O(n^2\log_2n)$,时间复杂度近似为 $O(n^2k\log_2n)$ 的算法。然而它在空间和时间上都会爆炸!!!

因此,我们需要优化。

优化

首先观察转移方程,可以发现 $st_{k,i,j}$ 的值只由 $st_{k-1}$ 决定,如果对背包问题比较熟悉的话就可以想到用滚动数组,这样空间复杂度就会降到 $O(n^2)$,空间问题解决了。

至于时间问题,其实我们可以发现每次转移时都会调用一次 max_element 函数,而调用一次这个函数最长需要花费我们 $O(k)$ 的时间,这导致了我们的转移方程很慢。

然而其实调用 max_element 函数的本质是为了求区间的最大值,而这个区间不但是定长的(只与 $st_{k,i,j}$ 的 $k$ 有关),而且是向右不断滑动的。滑动区间求最值问题,用单调队列优化可以把 $O(k)$ 优化成 $O(1)$,这样总体时间复杂度就降到了 $O(n^2\log_2k)$。

代码

但是不知道是不是我用了STL容器 deque 来模拟单调队列的缘故(听说STL容器很慢),不开O2优化的话,第2组测试数据全TLE,然而开了O2优化的话,最慢的一组数据只需要567ms。代码如下:

#include<bits/stdc++.h>
using namespace std;
int a[3002][3002] = { 0 }, st[3002][3002] = { 0 };
int main() {
    int n, K, maxk, ma, ju;
    long long ans = 0;
    deque<int>dui;//用双端队列模拟单调队列
    scanf("%d%d", &n, &K);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= i; ++j) {
            scanf("%d", *(a + i) + j); st[i][j] = a[i][j];
        }
    }
    maxk = log2(K);
    for (int k = 1; k <= maxk; ++k) {
        ma = n - (1 << k) + 1; ju = 1 << k - 1;
        for (int i = 1; i <= ma; ++i) {
            for (int j = 1; j <= ju; ++j) {//初始化队列
                while (!dui.empty() && st[i + ju][dui.back()] < st[i + ju][j])dui.pop_back();
                dui.push_back(j);
            }
            for (int j = 1; j <= i; ++j) {
                while (!dui.empty() && st[i + ju][dui.back()] < st[i + ju][j + ju])dui.pop_back();
                while (!dui.empty() && dui.front() < j)dui.pop_front();
                dui.push_back(j + ju);
                st[i][j] = max(st[i][j], st[i + ju][dui.front()]);
            }
            dui.clear();//清空队列
        }
    }
    ma = n - K + 1;
    if (K == 1 << maxk) {
        for (int i = 1; i <= ma; ++i) {
            for (int j = 1; j <= i; ++j) {
                ans += st[i][j];
            }
        }
    }
    else {
        ju = 1 << maxk;
        for (int i = 1; i <= ma; ++i) {
            for (int j = 1; j <= K - ju; ++j) {
                while (!dui.empty() && st[i + K - ju][dui.back()] < st[i + K - ju][j])dui.pop_back();
                dui.push_back(j);
            }
            for (int j = 1; j <= i; ++j) {
                while (!dui.empty() && st[i + K - ju][dui.back()] < st[i + K - ju][j + K - ju])dui.pop_back();
                while (!dui.empty() && dui.front() < j)dui.pop_front();
                dui.push_back(j + K - ju);
                ans += max(st[i][j], st[i + K - ju][dui.front()]);
            }
            dui.clear();
        }
    }
    printf("%lld", ans);
    return 0;
}