[题解] P10656 「ROI 2017 Day 2」学习轨迹

· · 题解

考虑暴力怎么写,枚举 a 中的一个区间,然后把 b 中出现了相同数的位置标记,求出未被标记的最大子段和,反过来做可以用链表维护和取 max,可以做到 O(n^2),理论上有 60 分的高分。

然后到这里似乎就做不下去了,这个时候需要观察到一个性质:如果 a,b 都有被选,那么至少有一个序列选的价值总和要大于该序列所有价值总和的一半,原因是如果都小于一半那调整成全选某个序列一定不劣。

于是假设 a 选的价值总和大于总价值的一半,b 反过来同理,那么第一个满足前缀和大于总和一半的位置 p 一定会被中,原因显然。

那我们现在可以确定 a 序列中有一个位置必选,考虑对 b 序列扫描线,扫描到某个右端点 r 时,维护所有左端点的答案。对于所有左端点,设 fl_i 表示选择区间 [i,r] 后,在 a 序列中 p 所在的极长连续段的左端点;fr_i 同理。我们在右端点增大时,需要支持对所有 fl_i\max,对所有 fr_i\min,并更新区间和。观察到 fl_i 从右往左单调不减,fr_i 从右往左单调不增,于是可以用直接维护连续段/单调栈,最后需要支持区间加、区间查询,用线段树维护,时间复杂度 O(n \log n)

Code:

#include <bits/stdc++.h>
using namespace std;
#define lc(x) x<<1
#define rc(x) x<<1|1
#define ll long long
int n, m, tp1, tp2, al, ar, bl, br, mxl[1000010], mnr[1000010], tagl[2000010], tagr[2000010], st1[500010], st2[500010];
ll ans, pre[500010], pre2[500010], tag[2000010];
struct node{
    int x, val;
}a[500010], b[500010], c[500010];
struct segment{
    ll mx;
    int pos, ls, rs;
}d[2000010];
segment operator + (segment x, segment y){
    segment ret;
    ret.mx = max(x.mx, y.mx);
    if (ret.mx == x.mx) ret.pos = x.pos, ret.ls = x.ls, ret.rs = x.rs;
    else ret.pos = y.pos, ret.ls = y.ls, ret.rs = y.rs;
    return ret;
}
void pushdown(int k){
    if (tag[k]){
        tag[lc(k)] += tag[k], tag[rc(k)] += tag[k];
        d[lc(k)].mx += tag[k], d[rc(k)].mx += tag[k];
        tag[k] = 0;
    }
    if (tagl[k]){
        tagl[lc(k)] = tagl[rc(k)] = d[lc(k)].ls = d[rc(k)].ls = tagl[k], tagl[k] = 0;
    }
    if (tagr[k]){
        tagr[lc(k)] = tagr[rc(k)] = d[lc(k)].rs = d[rc(k)].rs = tagr[k], tagr[k] = 0;
    }
}
void build(int k, int l, int r){
    tag[k] = tagl[k] = tagr[k] = 0;
    if (l == r){
        d[k].mx = -pre2[l-1] + pre[n];
        d[k].pos = l, d[k].ls = 1, d[k].rs = n;
        return ;
    }
    int mid = l + r >> 1;
    build(lc(k), l, mid);
    build(rc(k), mid+1, r);
    d[k] = d[lc(k)] + d[rc(k)];
}
void modify(int k, int l, int r, int x, int y, int ls, int rs){
    if (x <= l && r <= y){
        if (ls){
            d[k].mx -= (pre[ls-1] - pre[d[k].ls-1]), tag[k] -= (pre[ls-1] - pre[d[k].ls-1]);
            d[k].ls = tagl[k] = ls;
        }
        if (rs){
            d[k].mx -= (pre[d[k].rs] - pre[rs]), tag[k] -= (pre[d[k].rs] - pre[rs]);
            d[k].rs = tagr[k] = rs;
        }
        return ;
    }
    int mid = l + r >> 1;
    pushdown(k);
    if (x <= mid) modify(lc(k), l, mid, x, y, ls, rs);
    if (y > mid) modify(rc(k), mid+1, r, x, y, ls, rs);
    d[k] = d[lc(k)] + d[rc(k)];
}
segment query(int k, int l, int r, int x, int y){
    if (x <= l && r <= y) return d[k];
    int mid = l + r >> 1;
    pushdown(k);
    if (y <= mid) return query(lc(k), l, mid, x, y);
    if (x > mid) return query(rc(k), mid+1, r, x, y);
    return query(lc(k), l, mid, x, y) + query(rc(k), mid+1, r, x, y);
}
void work(int op){
    int pos = 0;
    for (int i=1; i<=n; i++) pre[i] = pre[i-1] + a[i].val;
    for (int i=1; i<=n; i++){
        if (pre[i] > pre[n] / 2){
            pos = i;
            break;
        }
    }
    for (int i=1; i<=n+m; i++) mxl[i] = 1, mnr[i] = n;
    for (int i=1; i<=pos; i++) mxl[a[i].x] = i + 1;
    for (int i=n; i>=pos; i--) mnr[a[i].x] = i - 1;
    for (int i=1; i<=m; i++) pre2[i] = pre2[i-1] + b[i].val;
    build(1, 1, m);
    tp1 = tp2 = 0;
    for (int i=1; i<=m; i++){
        while (tp1 && mxl[b[st1[tp1]].x] < mxl[b[i].x]){
            modify(1, 1, m, st1[tp1-1]+1, st1[tp1], mxl[b[i].x], 0);
            tp1 --;
        }
        while (tp2 && mnr[b[st2[tp2]].x] > mnr[b[i].x]){
            modify(1, 1, m, st2[tp2-1]+1, st2[tp2], 0, mnr[b[i].x]);
            tp2 --;
        }
        modify(1, 1, m, i, i, mxl[b[i].x], mnr[b[i].x]);
        segment ret = query(1, 1, m, 1, i);
        if (ret.mx + pre2[i] > ans){
            ans = ret.mx + pre2[i], al = ret.ls, ar = ret.rs, bl = ret.pos, br = i;
            if (op == 2) swap(al, bl), swap(ar, br);
        }
        st1[++tp1] = i, st2[++tp2] = i;
    }
}
int main(){
    scanf ("%d%d", &n, &m);
    for (int i=1; i<=n; i++){
        scanf ("%d", &a[i].x);
    }
    for (int i=1; i<=n; i++){
        scanf ("%d", &a[i].val);
    }
    for (int i=1; i<=m; i++){
        scanf ("%d", &b[i].x);
    }
    for (int i=1; i<=m; i++){
        scanf ("%d", &b[i].val);
    }
    work(1);
    if (pre[n] >= ans) ans = pre[n], al = 1, ar = n, bl = br = 0;
    if (pre2[m] >= ans) ans = pre2[m], al = ar = 0, bl = 1, br = m;
    memcpy(c, a, sizeof(a));
    memcpy(a, b, sizeof(b));
    memcpy(b, c, sizeof(c));
    swap(n, m);
    work(2);
    printf ("%lld\n%d %d\n%d %d\n", ans, al, ar, bl, br);
    return 0;
}