题解:P10784 【MX-J1-T4】『FLA - III』Wrestle

· · 题解

题意

n 条红色线段和 m 条蓝色线段,并且所有同颜色的线段两两间没有交集。红色线段有权重,蓝色没有。选出若干蓝色线段满足:

求满足这样情况下,最多的红蓝线段交集整点数是多少。

思路

首先注意到线段值域较大,考虑将所有红蓝线段离散化后重新映射下标。

然后发现蓝色线段的数量 m \le 5000 ,权重的限制 k \le 5000 ,对于每条蓝色线段来说只有选或不选,这就和我们学过的01背包非常相似了!我们可以把权重视为限制,设 dp[i][j][k] 表示以第 i 条线段结尾,权重限制为 j ,当 k = 0 时代表不选第 i 条,k = 1 时选择第 i 条的最大交集整点数。于是我们有转移方程:

dp[i][j][k] = \begin{cases} \max(dp[i - 1][j][0], dp[i - 1][j][1]), & k = 0 \\ \max(dp[l][j][0], dp[l][j][1]) + P_i, & k = 1 \end{cases}

其中 l 表示 < i 且不与 i 冲突(也就是和 i 不存在公共的红色线段)的最后一条线段。由于题目给的线段有非常好的性质,那么我们可以将线段按任意端点从小到大排序,就可以使用二分查找了!这样这道题的主要思路就讲完了,具体细节来看代码吧。

代码

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
#define ll long long
const int N = 200010, M = 5005;
struct Red {
    int l, r, w;
    bool operator< (const Red& other) const {
        return l < other.l;
    }
} red[N], blue[M];
ll dp[M][M][2] = { 0 };
int dis[(N + M) << 1], dcnt = 0;
int Kpres[N];
ll Lpres[N];
int get_ind(int val) {
    int l = 1, r = dcnt, mid;
    while (l < r) {
        mid = l + r >> 1;
        if (dis[mid] >= val) r = mid;
        else l = mid + 1;
    }
    return r;
}
int n, m, k;
//找到跟第 i 条线段冲突的第一个下标
int get_rgL(int lef) {
    int l = 0, r = n + 1, mid;
    while (l < r) {
        mid = l + r >> 1;
        if (red[mid].r >= lef) r = mid;
        else l = mid + 1;
    }
    return r;
}
//找到跟第 i 条线段冲突的最后一个下标
int get_rgR(int rig) {
    int l = 0, r = n + 1, mid;
    while (l < r) {
        mid = l + r + 1 >> 1;
        if (red[mid].l <= rig) l = mid;
        else r = mid - 1;
    }
    return l;
}
int unique() {
    int ret = 0, las = -1;
    for (int i = 1; i <= dcnt; i++) {
        if (dis[i] != las)
            dis[++ret] = dis[i];
        las = dis[i];
    }
    return ret;
}
//两条线段的交集整点数
inline int get_inter(int l1, int r1, int l2, int r2) {
    return max(0, min(r1, r2) - max(l1, l2) + 1);
}
int main() {
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; i++) {
        int l, r, w;
        scanf("%d%d%d", &l, &r, &w);
        red[i] = { l, r, w };
        dis[++dcnt] = l, dis[++dcnt] = r;
    }
    for (int i = 1; i <= m; i++) {
        int l, r;
        scanf("%d%d", &l, &r);
        blue[i] = { l, r, 0 };
        dis[++dcnt] = l, dis[++dcnt] = r;
    }
    sort(dis + 1, dis + 1 + dcnt);
    dcnt = unique();
    for (int i = 1; i <= n; i++)
        red[i].l = get_ind(red[i].l), red[i].r = get_ind(red[i].r);
    sort(red + 1, red + 1 + n);
    Kpres[0] = Lpres[0] = 0;
    for (int i = 1; i <= n; i++) {
        // Kpres 是权重的前缀和,Lpres 是长度前缀和
        Kpres[i] = Kpres[i - 1] + red[i].w;
        Lpres[i] = Lpres[i - 1] + (dis[red[i].r] - dis[red[i].l] + 1);
    }
    for (int i = 1; i <= m; i++)
        blue[i].l = get_ind(blue[i].l), blue[i].r = get_ind(blue[i].r);
    sort(blue + 1, blue + 1 + m);
    //放置哨兵,防止二分进行错误的命中
    red[0] = blue[0] = { -2, -1, 0 }, red[n + 1] = blue[m + 1] = { 10000000, 10000000 + 1, 0 };
    ll ans = 0;
    for (int i = 1; i <= m; i++) {
        int l = get_rgL(blue[i].l), r = get_rgR(blue[i].r);
        if (l > r) memcpy(dp[i], dp[i - 1], sizeof dp[i]);
        else {
            // p 是价值,w 是重量
            ll p = 0, w = Kpres[r] - Kpres[l - 1];
            if (l + 1 <= r - 1) p += Lpres[r - 1] - Lpres[l];
            if (l != r) p += get_inter(dis[blue[i].l], dis[blue[i].r],
                dis[red[r].l], dis[red[r].r]);
            p += get_inter(dis[blue[i].l], dis[blue[i].r],
                dis[red[l].l], dis[red[l].r]);
            int L = 0, R = m + 1, mid;
            //右端点大于等于最左边一项的右端点 
            while (L < R) {
                mid = L + R >> 1;
                if (blue[mid].r >= red[l].l) R = mid;
                else L = mid + 1;
            }
            r = R - 1;
            for (int j = 0; j <= k; j++) {
                dp[i][j][0] = max(dp[i - 1][j][0], dp[i - 1][j][1]);
                dp[i][j][1] = dp[i][j][0];
                if (j >= w)
                    dp[i][j][1] = max(dp[i][j][1], max(dp[r][j - w][0], dp[r][j - w][1]) + p);
                ans = max(ans, max(dp[i][j][0], dp[i][j][1]));
            }
        }
    }
    printf("%lld", ans);
    return 0;
}