P7936 题解

· · 题解

线性dp && 维护前缀和

题意

求从点 1 到点 n 的一条路径,并打印路径。

思路

  1. 线性dp

    • 无后效性:每次移动,横纵坐标只会变大或不变。
    • 最优子结构: 显然。
  2. 前缀和

  3. 排序

    • 保证dp顺序。
    • 维护前缀和。

实现

于是有:

f[i] = max(f[j] - cost + w[i])

代码

// https://www.luogu.com.cn/problem/P7936
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
using ll = long long;
const int inf = 0x3f3f3f3f;
#define rf(i, n) for (int i = 1; i <= (n); ++i)
int rad(); // read int
const int max_size = 5 + 1e5;
const int maxn = 5 + 3e5;

int n, cost;
struct Pos {
    int x, y, w, id;
    bool operator<(const Pos &other) const {
        if (x != other.x) return x < other.x;
        return y < other.y;
    }
} pos[maxn];

int pre[maxn];              // pre[i]: 到达 node i  的最优解的前一步
int f[maxn];                // 到达节点 i 所剩能量的最大值
int prex[maxn], prey[maxn]; //
// f[i] = max(f[j] - cost + w[i]), j to i

void show(int now, int dep = 0) {
    if (now == 0) {
        printf("%d\n", dep);
        return;
    };
    show(pre[now], dep + 1);
    printf("%d %d\n", pos[now].x, pos[now].y);
}

void update(int now, int from) {
    if (from == 0 || f[from] < cost) return;
    if (f[now] < f[from] - cost + pos[now].w) {
        f[now] = f[from] - cost + pos[now].w;
        pre[now] = from;
    }
}

int main() {
    n = rad(), cost = rad();
    rf(i, n) {
        pos[i].x = rad(), pos[i].y = rad(), pos[i].w = rad();
        pos[i].id = i;
    }
    sort(pos + 1, pos + 1 + n);
    int start = 1, finish = 1; // 找到起点,终点
    while (pos[start].id != 1)
        start++;
    while (pos[finish].id != n)
        finish++;

    f[start] = pos[start].w, prex[pos[start].x] = start, prey[pos[start].y] = start;
    for (int i = start + 1; i <= finish; ++i) {
        update(i, prex[pos[i].x]);
        update(i, prey[pos[i].y]);
        if (f[i] > f[prex[pos[i].x]]) prex[pos[i].x] = i;
        if (f[i] > f[prey[pos[i].y]]) prey[pos[i].y] = i;
    }
    printf("%d\n", f[finish]);
    show(finish);
}

int rad() {
    int back = 0, ch = 0, neg = 0;
    for (; ch < '0' || ch > '9'; ch = getchar())
        neg = ch == '-';
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        back = (back << 1) + (back << 3) + (ch & 0xf);
    return (back ^ -neg) + neg;
}