线段树优化 DP 板子

· · 题解

线段树优化 DP 的板子。

题意

给定一个数列,求出它最长的子序列,满足相邻元素之差不大于 k

分析

对于这类最长子序列的题目,有一个 DP 套路,设 dp_i 表示以 a_i 结尾的最长合法子序列长度,则有

dp_i=\max_{j=1}^{j<i} dp_j+1

其中,|a_i-a_j| \le k

Code:

#include <bits/stdc++.h>
#define N 300005
using namespace std;

int n, k, a[N], dp[N];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> k; for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n; i++) {
        dp[i] = 1;
        for (int j = 1; j < i; j++) {
            if (abs(a[i] - a[j]) <= k) dp[i] = max(dp[i], dp[j] + 1);
        }
    }
    cout << *max_element(dp + 1, dp + n + 1);
    return 0;
}

喜提 TLE。

显然,上述思路的复杂度为 O(n^2),考虑优化。

考虑下面的过程:

for (int j = 1; j < i; j++) {
    if (abs(a[i] - a[j]) <= k) dp[i] = max(dp[i], dp[j] + 1);
}

什么样的 j 可以转移呢?

当且仅当 a_j \in [a_i-k, a_i+k] 时,可以进行转移。此时,题目被转化成了一个 RMQ 问题,需要单点修改,区间查询最值,想到线段树。

AC Code 如下,注意边界条件。

#include <bits/stdc++.h>
#define N 300005
using namespace std;

int n, k, a[N];

template <class T>
class SegmentTree {
#define ls (rt << 1)
#define rs (rt << 1 | 1)
private:
    int n;
    T maxn[(N << 2) + 1];

    void pushup(int rt) {
        maxn[rt] = max(maxn[ls], maxn[rs]);
    }

    void build(int l, int r, int rt) {
        if (l == r) return (void) (maxn[rt] = 0);
        int mid = (l + r) >> 1;
        build(l, mid, ls), build(mid + 1, r, rs);
        pushup(rt);
    }

    void update(int i, const T c, int l, int r, int rt) {
        if (l == r) return (void) (maxn[rt] = max(maxn[rt], c));
        int mid = (l + r) >> 1;
        if (i <= mid) update(i, c, l, mid, ls);
        else update(i, c, mid + 1, r, rs);
        pushup(rt);
    }

    T query(int tl, int tr, int l, int r, int rt) {
        if (tl <= l && r <= tr) return maxn[rt];
        int mid = (l + r) >> 1; T res(0);
        if (tl <= mid) res = max(res, query(tl, tr, l, mid, ls));
        if (tr > mid) res = max(res, query(tl, tr, mid + 1, r, rs));
        return res;
    }

public:
    explicit SegmentTree(int _n) : n(_n) {build(1, n, 1);}
    void update(int i, const T c) {update(i, c, 1, n, 1);}
    T query(int l, int r) {return query(l, r, 1, n, 1);}

#undef ls
#undef rs
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> k; for (int i = 1; i <= n; i++) cin >> a[i];
    int m = *max_element(a + 1, a + n + 1);
    SegmentTree<int> sgt(m);
    for (int i = 1; i <= n; i++) {
        sgt.update(a[i], sgt.query(max(1, a[i] - k), min(m, a[i] + k)) + 1);
    }
    cout << sgt.query(0, m);
    return 0;
}

线段树优化后,复杂度为 O(n \log w),其中 wa_i 的值域。可以通过本题。