题解:P7564 [JOISC 2021] ボディーガード (Day3)

· · 题解

套路拼拼拼的题?不过挺有意思的。

首先给题目加上时间维,在二维平面上去考虑每个操作。下文钦定横轴为时间维,纵轴为位置维。那么每个顾客都可以用 (t_i,a_i)\rightarrow (t_i+|a_i-b_i|,b_i) 这个线段表示。保镖同理。

容易发现所有的线段的斜率均为 \pm 1,考虑给整个坐标系逆时针旋转 45 \degree,即 (x,y)\rightarrow (x+y,x-y),那么所有的线段均与 x 轴或 y 轴平行。但这个坐标系的坐标范围太大了,考虑将顾客的坐标离散化下来,坐标系的大小就是 \mathcal{O}(n^2) 的,我们可以在上面做很多事情。

给每个线段一个权值,就是 \operatorname{len} \times c_i,注意因为我们旋转了坐标系,根据勾股定理,新的长度即为现在的 \operatorname{len} \times \frac{c_i}{2},那么现在问题变为给你一个点,你可以向右或向上走,权值为经过的线段的权值和。

但是注意这个点不一定是整点,因为保镖没有加入离散化的过程,如果是整点直接 dp 一遍每个点开始走的答案表即为 \mathcal{O}(n^2)。考虑不是整点的情况。那么可以先向右再向上,或者先向上再向右,权值即为向上或向右的那一段。

只讨论一种情况,另外一种同理,钦定先向右再向上。对 x 扫描线,然后从上往下扫 y,先将询问挂在整点上方便处理,然后询问的答案就是一个整点的答案加上 \operatorname{len} \times c_i 的最大值,这是一个一次函数的形式,斜率为 c,截距为整点的答案。直接用李超树维护即可。

总时间复杂度 \mathcal{O}(n^2 \log V + q \log V)

#include <bits/stdc++.h>
#define int long long
#define rd read()
using namespace std;
inline int read() {
    int x = 0; char ch = getchar();
    while (ch < '0' || ch > '9')ch = getchar();
    while (ch >= '0' && ch <= '9')x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    return x;
}
const int N = 5605, M = 3e6 + 5;
struct Node {int x0, y0, x1, y1, v;} p[N];
struct Query {int x, y, id;} t[M];
struct Line {int k, b;} L[N];
#define gt(id, x) L[id].k * x + L[id].b
int n, q, bx[N << 1], by[N << 1], cnt1, cnt2, ans[M], w1[N][N], w2[N][N], rt, tot, f[N][N], cnt;
vector<int> F[N], g[N];
struct LCT {
    int tr[N * 20], lc[N * 20], rc[N * 20];
    inline void init() {rt = tot = 0;}
    inline void ins(int &k, int l, int r, int id) {
        if (!k)return k = ++tot, tr[k] = id, lc[k] = rc[k] = 0, void();
        int mid = l + r >> 1;
        if (gt(id, mid) > gt(tr[k], mid))swap(id, tr[k]);
        if (gt(id, l) > gt(tr[k], l))ins(lc[k], l, mid, id);
        else if (gt(id, r) > gt(tr[k], r))ins(rc[k], mid + 1, r, id);
    }   
    inline int ask(int k, int l, int r, int p) {
        if (!k)return 0;
        if (l == r)return gt(tr[k], p);
        int mid = l + r >> 1;
        return max(gt(tr[k], p), mid >= p ? ask(lc[k], l, mid, p) : ask(rc[k], mid + 1, r, p));
    }
} T;
inline int dfs(int x, int y) {
    if (x > cnt1 || y > cnt2)return 0;      
    if (~f[x][y])return f[x][y];
    int v1 = 0, v2 = 0;
    if (x < cnt1)v1 = dfs(x + 1, y) + w2[x][y] * (bx[x + 1] - bx[x]);
    if (y < cnt2)v2 = dfs(x, y + 1) + w1[x][y] * (by[y + 1] - by[y]);
    return f[x][y] = max(v1, v2);
}
inline bool cmp1(int x, int y) {return t[x].y > t[y].y;}
inline bool cmp2(int x, int y) {return t[x].x > t[y].x;} 
signed main() {
    n = rd, q = rd; memset(f, -1, sizeof f);
    for (int i = 1; i <= n; ++i) {
        int tim = rd, a = rd, b = rd, c = rd;
        p[i] = {tim, a, tim + abs(a - b), b};
        int X0 = p[i].x0, Y0 = p[i].y0, X1 = p[i].x1, Y1 = p[i].y1;
        p[i] = {X0 + Y0, X0 - Y0, X1 + Y1, X1 - Y1, c >> 1};
        bx[++cnt1] = p[i].x0, bx[++cnt1] = p[i].x1;
        by[++cnt2] = p[i].y0, by[++cnt2] = p[i].y1;
    }
    sort(bx + 1, bx + 1 + cnt1), sort(by + 1, by + 1 + cnt2);
    cnt1 = unique(bx + 1, bx + 1 + cnt1) - bx - 1, cnt2 = unique(by + 1, by + 1 + cnt2) - by - 1;
    for (int i = 1; i <= n; ++i)
    p[i] = {lower_bound(bx + 1, bx + 1 + cnt1, p[i].x0) - bx, lower_bound(by + 1, by + 1 + cnt2, p[i].y0) - by, lower_bound(bx + 1, bx + 1 + cnt1, p[i].x1) - bx, lower_bound(by + 1, by + 1 + cnt2, p[i].y1) - by, p[i].v};
    for (int i = 1; i <= q; ++i) {
        t[i] = {rd, rd, i};
        int x = t[i].x, y = t[i].y;
        t[i].x = x + y, t[i].y = x - y;
        x = lower_bound(bx + 1, bx + 1 + cnt1, t[i].x) - bx, y = lower_bound(by + 1, by + 1 + cnt2, t[i].y) - by;
        F[x].push_back(i), g[y].push_back(i);
    }
    for (int i = 1; i <= n; ++i) {  
        if (p[i].x0 == p[i].x1) for (int j = p[i].y0; j < p[i].y1; ++j)w1[p[i].x0][j] = max(w1[p[i].x0][j], p[i].v);
        else for (int j = p[i].x0; j < p[i].x1; ++j)w2[j][p[i].y0] = max(w2[j][p[i].y0], p[i].v);
    }
    L[0] = {0, INT_MIN};
    for (int i = 1; i <= cnt1; ++i) {
        sort(F[i].begin(), F[i].end(), cmp1);
        T.init(); int las = cnt2;
        cnt = 0;
        for (auto j : F[i]) {
            while (las && by[las] >= t[j].y) {
                L[++cnt] = {w2[i - 1][las], dfs(i, las)}; 
                T.ins(rt, 0, 1e9, cnt);
                --las; 
            } 
            ans[t[j].id] = T.ask(rt, 0, 1e9, bx[i] - t[j].x); 
        }
    }
    for (int i = 1; i <= cnt2; ++i) {
        sort(g[i].begin(), g[i].end(), cmp2);
        T.init(); int las = cnt1;
        cnt = 0;
        for (auto j : g[i]) {
            while (las && bx[las] >= t[j].x) {
                L[++cnt] = {w1[las][i - 1], dfs(las, i)};
                T.ins(rt, 0, 1e9, cnt);
                --las;
            } 
            ans[t[j].id] = max(ans[t[j].id], T.ask(rt, 0, 1e9, by[i] - t[j].y));
        }
    }
    for (int i = 1; i <= q; ++i)cout << ans[i] << '\n';
    return 0;
}