slope trick

· · 题解

比较优秀的一道 slope trick 题目。是一个练习带权 slope trick 的题,并不需要手写数据结构加速。

Rosabel 大佬的线段树解法太长了,像我这样的刚学 OI 的萌新完全写不了,这里介绍一种很玄学的 std::map 维护的方法。

我的方法大部分来自官方题解,一部分来自 Rosabel 大佬的题解,如果部分推理过程与官方题解雷同,属正常现象。

不会维护拐点的 slope trick 的人看起来可能比较费劲。

开始讲解

先不思考如何记录过程,下面先说一下转移方法。

考虑朴素 DP,定义 f_{i,j} 为吃掉前 i 个商人后,位置在 j 的最小代价,范围为 i\in[0,n]j\in[\min x,\max x],初始化为 f_{0,j}=c\times|j|,答案为 \min f_{n,j}

转移为 dp_{i,j}=\min dp_{i-1,k}+c\times|j-k|+d\times|x_i-j|,其中 c\times|j-k| 表示移动 You,d\times|x_i-j| 表示移动商人。

下面开始考虑如何维护。

这是一个基础的带权 slope trick 题目,后面讲得会很慢,

两步分开处理,我们用经典的 slope trick 维护方法,左右各开一个数据结构维护拐点信息,平常我们都是使用 std::priority_queue 或者 std::multiset 维护,但是这一题中拐点斜率变化量可以达到 O(n\times c) 级别,因此改为 std::map。(我是不会告诉你我不会自己写数据结构的)

我们用两个 std::map 维护拐点,斜率为 0 的段(或点)左右分别为 LR,每一对 (Key,Value) 代表函数 fKey 处斜率变化了 Value

上面的过程中,第一步形如让斜率小于 -c 的都变为 -c,斜率大于 c 的都变为 c,这对应在左右两只 std::map 上的维护就是不断删或减小每一侧的 std::mapValue,直到每一侧的 Value 和均为 c

第二步形如插入一个绝对值的倍数,这是维护拐点的 slope trick 很 static 的操作,将这玩意分为 d(x_i-j)_+d(j-x_i)_+。考虑到插入一堆基础函数后,DP 数组 f 值取 \min f 的点可能无法在点 \min\{R\}\max\{L\} 上了,我们就需仔细地维护 \min f。这里我们就必须一点一点维护 \min f,形象的说,就是一点一点插入那 2d 个基础函数,慢慢地转这个,每次根据到底是加入 (x_i-j)_+ 还是 (j-x_i)_+ 只插入 \max\{L\}.Value\min\{R\}.Value 个。

因此我们就可以写出一个最基础的代码了。

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, c, d, Lcnt, Rcnt, x[1000005], minf;
map<int, int> L, R;
signed main() {
    cin >> n >> c >> d;
    L[0] = R[0] = c;
    Lcnt = Rcnt = c;
    for (int i = 1; i <= n; i++) {
        cin >> x[i];
        while (Lcnt > c) { // 削掉 < -c 的斜率
            if (L.begin()->second > Lcnt - c) {
                L.begin()->second -= Lcnt - c;
                Lcnt = c;
            } else {
                Lcnt -= L.begin()->second;
                L.erase(L.begin());
            }
        }
        while (Rcnt > c) { // 削掉 > c 的斜率
            if ((--R.end())->second > Rcnt - c) {
                (--R.end())->second -= Rcnt - c;
                Rcnt = c;
            } else {
                Rcnt -= (--R.end())->second;
                R.erase(--R.end());
            }
        }
        int add_k = d;
        if (R.begin()->first >= x[i]) { // 插入 (j-x[i])+
            L[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                if (R.begin()->second > td) { // 当前的 min{R}.Value 已经够大了,將剩下的 td 个 (j-x[i])+ 全部插入,minf 还可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * td;
                    L[R.begin()->first] += td;
                    R.begin()->second -= td;
                    R[x[i]] += td;
                    td = 0;
                } else {
                    int num = R.begin()->second; // 当前的 min{R}.Value 不够大,先放进去 min{R}.Value 个 (j-x[i])+,minf 可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * num;
                    td -= num;
                    L[R.begin()->first] += num; // 随后是维护拐点的 slope trick 的标准代码
                    R.erase(R.begin());
                    R[x[i]] += num; // 这次只放 min{R}.Value 个,并让新的新的 minf 可以在新的 min{R}.Value 上取到
                }
            }
        }
        if ((--L.end())->first <= x[i]) { // 插入 (x[i]-j)+,同理
            R[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                map<int, int>::iterator it = --L.end();
                if (it->second > td) {
                    minf += max(0ll, it->first - x[i]) * td;
                    R[it->first] += td;
                    it->second -= td;
                    L[x[i]] += td;
                    td = 0;
                } else {
                    int num = it->second;
                    minf += max(0ll, it->first - x[i]) * num;
                    td -= num;
                    R[it->first] += num;
                    L.erase(it);
                    L[x[i]] += num;
                }
            }
        }
        Lcnt += add_k; // 维护一下 L 和 R 的和
        Rcnt += add_k;
    }
    cout << minf << endl;
    return 0;
}

但是这个会 TLE 九个点,为什么?因为复杂度不正确,如果这个 c 恰好卡出了一堆 Value 很小的节点然后 x 不断在 -\inf+\inf 之间跳,我们会把这些 Value 特别小的折点不断地在 LR 间移动。

怎么优化?我们可以只插入 \min\{d,2\times c\} 个函数,可以证明,这样正确性是没影响的(因为插入多了也会被有关 c 的限制砍掉)。

我们将代码中的 add_k 设为 \min\{d,2\times c\} 就可以正确而快速地算出答案。

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, c, d, Lcnt, Rcnt, x[1000005], minf;
map<int, int> L, R;
signed main() {
    cin >> n >> c >> d;
    L[0] = R[0] = c;
    Lcnt = Rcnt = c;
    for (int i = 1; i <= n; i++) {
        cin >> x[i];
        while (Lcnt > c) { // 削掉 < -c 的斜率
            if (L.begin()->second > Lcnt - c) {
                L.begin()->second -= Lcnt - c;
                Lcnt = c;
            } else {
                Lcnt -= L.begin()->second;
                L.erase(L.begin());
            }
        }
        while (Rcnt > c) { // 削掉 > c 的斜率
            if ((--R.end())->second > Rcnt - c) {
                (--R.end())->second -= Rcnt - c;
                Rcnt = c;
            } else {
                Rcnt -= (--R.end())->second;
                R.erase(--R.end());
            }
        }
        int add_k = min(d, 2 * c);
        if (R.begin()->first >= x[i]) { // 插入 (j-x[i])+
            L[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                if (R.begin()->second > td) { // 当前的 min{R}.Value 已经够大了,將剩下的 td 个 (j-x[i])+ 全部插入,minf 还可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * td;
                    L[R.begin()->first] += td;
                    R.begin()->second -= td;
                    R[x[i]] += td;
                    td = 0;
                } else {
                    int num = R.begin()->second; // 当前的 min{R}.Value 不够大,先放进去 min{R}.Value 个 (j-x[i])+,minf 可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * num;
                    td -= num;
                    L[R.begin()->first] += num; // 随后是维护拐点的 slope trick 的标准代码
                    R.erase(R.begin());
                    R[x[i]] += num; // 这次只放 min{R}.Value 个,并让新的新的 minf 可以在新的 min{R}.Value 上取到
                }
            }
        }
        if ((--L.end())->first <= x[i]) { // 插入 (x[i]-j)+,同理
            R[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                map<int, int>::iterator it = --L.end();
                if (it->second > td) {
                    minf += max(0ll, it->first - x[i]) * td;
                    R[it->first] += td;
                    it->second -= td;
                    L[x[i]] += td;
                    td = 0;
                } else {
                    int num = it->second;
                    minf += max(0ll, it->first - x[i]) * num;
                    td -= num;
                    R[it->first] += num;
                    L.erase(it);
                    L[x[i]] += num;
                }
            }
        }
        Lcnt += add_k; // 维护一下 L 和 R 的和
        Rcnt += add_k;
    }
    cout << minf << endl;
    return 0;
}

先别走,我们还没有记录方案,记录方案的方法很简单,我们每一次砍完 c,留下的 \min\{L\}\max\{R\} 外的点一定是从 \min\{L\}\max\{R\} 转移的,因此记录每一步砍完 c 后的 \min\{L\}\max\{R\} 即可。

输出时,我们知道 \min f 取值的点([\max\{L\},\min\{R\}] 内的点都行),然后我们倒着计算方案,如果超过了记录的 \max\{R\},就让人退到 \max\{R\},如果超过了记录的 \min\{L\},就让人退到 \min\{L\}

完整代码,91 行,2KB 出头:

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, c, d, Lcnt, Rcnt, current_min[1000005], current_max[1000005], x[1000005], minf;
map<int, int> L, R;
signed main() {
    cin >> n >> c >> d;
    L[0] = R[0] = c;
    Lcnt = Rcnt = c;
    for (int i = 1; i <= n; i++) {
        cin >> x[i];
        while (Lcnt > c) { // 削掉 < -c 的斜率
            if (L.begin()->second > Lcnt - c) {
                L.begin()->second -= Lcnt - c;
                Lcnt = c;
            } else {
                Lcnt -= L.begin()->second;
                L.erase(L.begin());
            }
        }
        while (Rcnt > c) { // 削掉 > c 的斜率
            if ((--R.end())->second > Rcnt - c) {
                (--R.end())->second -= Rcnt - c;
                Rcnt = c;
            } else {
                Rcnt -= (--R.end())->second;
                R.erase(--R.end());
            }
        }
        current_min[i] = L.begin()->first;
        current_max[i] = (--R.end())->first;
        int add_k = min(d, 2 * c);
        if (R.begin()->first >= x[i]) { // 插入 (j-x[i])+
            L[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                if (R.begin()->second > td) { // 当前的 min{R}.Value 已经够大了,將剩下的 td 个 (j-x[i])+ 全部插入,minf 还可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * td;
                    L[R.begin()->first] += td;
                    R.begin()->second -= td;
                    R[x[i]] += td;
                    td = 0;
                } else {
                    int num = R.begin()->second; // 当前的 min{R}.Value 不够大,先放进去 min{R}.Value 个 (j-x[i])+,minf 可以在 min{R} 取值
                    minf += max(0ll, x[i] - R.begin()->first) * num;
                    td -= num;
                    L[R.begin()->first] += num; // 随后是维护拐点的 slope trick 的标准代码
                    R.erase(R.begin());
                    R[x[i]] += num; // 这次只放 min{R}.Value 个,并让新的新的 minf 可以在新的 min{R}.Value 上取到
                }
            }
        }
        if ((--L.end())->first <= x[i]) { // 插入 (x[i]-j)+,同理
            R[x[i]] += add_k;
        } else {
            int td = add_k;
            while (td) {
                map<int, int>::iterator it = --L.end();
                if (it->second > td) {
                    minf += max(0ll, it->first - x[i]) * td;
                    R[it->first] += td;
                    it->second -= td;
                    L[x[i]] += td;
                    td = 0;
                } else {
                    int num = it->second;
                    minf += max(0ll, it->first - x[i]) * num;
                    td -= num;
                    R[it->first] += num;
                    L.erase(it);
                    L[x[i]] += num;
                }
            }
        }
        Lcnt += add_k; // 维护一下 L 和 R 的和
        Rcnt += add_k;
    }
    cout << minf << endl;
    int ans_pos = R.begin()->first;
    vector<int> ans_poss;
    for (int i = n; i >= 1; i--) {
        ans_poss.push_back(ans_pos);
        ans_pos = max(min(current_max[i], ans_pos), current_min[i]);
    }
    reverse(ans_poss.begin(), ans_poss.end());
    for (int i : ans_poss) {
        cout << i << " ";
    }
    return 0;
}