浅谈 Slope Trick:神奇的 DP 优化方法

· · 算法·理论

Slope Trick 是一种优化二维 DP 的方法,它要求 DP 数组的第二维是一个凸函数,通过记录一些关于这个凸函数的关键信息,利用数据结构进行快速转移。

可以使用 Slope Trick 优化的前提是:

我们可以将这个函数每一个斜率变化 1 的位置使用数据结构来维护(如果一个地方斜率变化 \gt 1 就存储多个相同的值),并记录整个函数最左边或最右边的解析式,就可以唯一确定这个函数了。

我们假设函数是下凸函数,且维护最右边的解析式 y = kx + b

假设有如下这个函数:

y = \begin{cases} -x - 4, & x \le -2 \\ -2, & -2 \lt x \le 2 \\ 2x - 6, & x \gt 2 \end{cases}

图像为:

那么我们就可以维护最右边的解析式,其中 k=2,b=-6 以及函数的分段点集合 \{-2,2,2\} 表示这个函数。因为在 x=2 处斜率变化了 2,所以集合中记录了两个 2

这样的函数有着一个非常好用的性质:

下面介绍几种常见的操作:

例题:P4597 序列 sequence

Slope Trick 模板题。容易发现不用考虑“修改后的数列只能出现修改前的数”这个条件,必然存在一个最优解使得该条件成立。设 dp_{i,j} 表示当前处理到第 i 个数,将这个数改为 j 的最小操作次数。那么转移方程式为:

dp_{i,j} = \min_{k \le j} dp_{i-1,k} + |a_i-j|

设函数 f_i(x) = dp_{i,x},考虑使用数学归纳法证明 f_i(x) 为下凸函数:

综上所述:\forall i \in [1,n],f_i(x) 为下凸函数。

该 dp 的转移实际上就是对上一个下凸函数取前缀 \min 再加上一个凸函数。那么我们就用一个大根堆维护 f_i(x) 的前缀 \min 就行了。此外我们在用 k,b 维护最右侧的解析式,实际上 k=0 恒成立不用维护,就维护 b 就行了。实际上最终的 b 也就是答案。

代码示例:

#include <iostream>
#include <queue>

using namespace std;

using ll = long long;

int main()
{
    ll ans = 0;
    priority_queue<int> q;
    int n; cin >> n;

    for (int i = 1; i <= n; i++)
    {
        int x; cin >> x;
        q.push(x); q.push(x);
        ans += q.top() - x;
        q.pop();
    }
    cout << ans << endl;

    return 0;
}

例题:P12074 [OOI 2025] The Arithmetic exercise

经过转化,可以得到如下方程式:(具体推导过程可以参见这篇题解)

dp_{i,j} = \begin{cases} dp_{i-1,j+1} - a_i, & j = 0 \\ max\{dp_{i-1,j-1} + a_i, dp_{i-1,j+1} - a_i\}, & 0 \lt j \lt n \\ dp_{i-1,j-1} + a_i, & j = n \end{cases}

边界情况:

dp_{0,j} = \begin{cases} 0, & j = 0 \\ -\infty, & 0 \lt j \le n \end{cases}

设函数 f_i(x) = dp_{i,x},不难发现 f_i(x) 的所有有效点(即不为 -\infty)是一个上凸函数,动态规划的转移实际上就是将这个凸函数与一个点数为 2 的凸包作闵可夫斯基和。在计算几何中求解闵可夫斯基和时,我们的凸包都是用向量表示的,类似的在这里我们就可以用一种与上一道例题不同的存储凸函数的方式:用 multiset 存贮相邻点纵坐标的差并记录点一个的纵坐标。这就可以方便的转移了,最后再枚举每一个点取最大值就得到了答案。 代码示例:

#include <iostream>
#include <functional>
#include <iterator>
#include <set>

using namespace std;

using ll = long long;
constexpr int maxn = 300000;
int a[maxn + 5];

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t; cin >> t;
    while (t--)
    {
        int n, m; cin >> n >> m;
        for (int i = 1; i <= m; i++)
            cin >> a[m - i + 1];

        ll val = 0;       //  维护第一个点的值
        multiset<int, greater<int>> st;
        for (int i = 1; i <= m; i++)
        {
            st.insert(2 * a[i]);
            val -= a[i];
            if (i & 1)
            {
                val += *st.begin();
                st.erase(st.begin());
            }
            while (st.size() >= (i & 1 ? (n + 1) >> 1 : (n + 2) >> 1))
                st.erase(prev(st.end()));
        }

        ll ans = val;
        for (int i : st)
        {
            val += i;
            ans = max(ans, val);
        }
        cout << ans << endl;
    }

    return 0;
}