[题解] P10656 「ROI 2017 Day 2」学习轨迹
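考虑暴力怎么写:枚举 $b$ 中选取的区间 $[l, r]$,那么 $b_l, \dots, b_r$ 中出现过的元素都不能出现在 $a$ 选取的区间里,在这个限制下求 $a$ 中权值和最大的合法区间。
然后到这里似乎就做不下去了,这个时候需要观察到一个性质:如果 $a$、$b$ 选出的区间都不包含各自的「带权中点」(前缀和第一次超过总和一半的位置),那么两段的权值和都不超过各自总和的一半,加起来不超过 $\max\left(\sum a, \sum b\right)$,不如直接选一整个序列。
于是假设 $a$ 选出的区间 $[L, R]$ 包含 $a$ 的带权中点 $pos$;$b$ 的区间包含 $b$ 的带权中点的情况,交换 $a$、$b$ 再做一遍即可。
那我们现在可以确定 $b$ 中每个元素对 $[L, R]$ 的限制:设它在 $a$ 中的位置为 $p$,若 $p \le pos$ 则 $L \ge p + 1$,若 $p \ge pos$ 则 $R \le p - 1$(不在 $a$ 中出现则没有限制)。于是对固定的 $[l, r]$,最优的 $[L, R]$ 就是左端点限制的最大值和右端点限制的最小值。从小到大扫描 $r$,用两个单调栈维护这两个最值的变化,再用线段树对每个 $l$ 维护「$a$ 的区间和减去 $b$ 的前缀和 $pre2_{l-1}$」的最大值即可。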
Code:
#include <bits/stdc++.h>
using namespace std;
// Child indices of segment-tree node x: the left child is 2x, the right child is 2x+1.
// Both the argument and the whole expansion are parenthesized so the macros
// keep their meaning inside larger expressions (e.g. `lc(k) + 1`, `rc(a + b)`).
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
// Shorthand for the 64-bit integer type used for the weight sums.
// Kept as a macro so every existing use of `ll` keeps compiling unchanged.
#define ll long long
// Lengths of the two sequences: a[] has n elements, b[] has m elements.
int n, m;
// Current heights of the two index stacks st1 / st2 maintained in work().
int tp1, tp2;
// Ranges belonging to the best answer found so far:
// [al, ar] in the first input sequence and [bl, br] in the second.
int al, ar, bl, br;
// Per-value bounds filled in by work(), indexed by element value:
// mxl[v] is a lower bound for the left end, mnr[v] an upper bound for the right end.
int mxl[1000010], mnr[1000010];
// Lazy "assign" tags for the ls / rs fields of the segment tree; 0 means no pending tag.
int tagl[2000010], tagr[2000010];
// Stacks of indices into b[] used by the sweep in work() (slot 0 is never written and stays 0).
int st1[500010], st2[500010];
// Largest total weight found so far.
ll ans;
// Prefix sums of a[].val (pre) and of b[].val (pre2); index 0 holds 0.
ll pre[500010], pre2[500010];
// Lazy additive tag for the mx field of the segment tree; 0 means no pending tag.
ll tag[2000010];
// One sequence element: x is the element's value, val is its weight.
struct node{
    int x;
    int val;
};
// a and b hold the two sequences (filled from index 1 by main);
// c is scratch storage used when main swaps the contents of a and b.
node a[500010];
node b[500010];
node c[500010];
// Summary stored in every segment-tree node.
struct segment{
    ll mx;    // largest value among the leaves covered by the node
    int pos;  // index of the leaf that attains mx
    int ls;   // left bound currently associated with that leaf
    int rs;   // right bound currently associated with that leaf
};
// Segment-tree storage; node k has children lc(k) and rc(k).
segment d[2000010];
// Combines two summaries by keeping the one with the larger mx.
// When both maxima are equal the left operand x is kept.
segment operator + (segment x, segment y){
    if (x.mx >= y.mx) return x;
    return y;
}
void pushdown(int k){
if (tag[k]){
tag[lc(k)] += tag[k], tag[rc(k)] += tag[k];
d[lc(k)].mx += tag[k], d[rc(k)].mx += tag[k];
tag[k] = 0;
}
if (tagl[k]){
tagl[lc(k)] = tagl[rc(k)] = d[lc(k)].ls = d[rc(k)].ls = tagl[k], tagl[k] = 0;
}
if (tagr[k]){
tagr[lc(k)] = tagr[rc(k)] = d[lc(k)].rs = d[rc(k)].rs = tagr[k], tagr[k] = 0;
}
}
// Builds the segment tree for leaves l..r below node k and clears all lazy tags.
// Leaf l starts with bounds [1, n] and value pre[n] - pre2[l-1].
void build(int k, int l, int r){
    tag[k] = 0;
    tagl[k] = 0;
    tagr[k] = 0;

    if (l != r){
        // Internal node: build both halves, then pull their summaries up.
        const int mid = (l + r) >> 1;
        build(lc(k), l, mid);
        build(rc(k), mid + 1, r);
        d[k] = d[lc(k)] + d[rc(k)];
        return;
    }

    // Leaf.
    d[k].mx = pre[n] - pre2[l-1];
    d[k].pos = l;
    d[k].ls = 1;
    d[k].rs = n;
}
void modify(int k, int l, int r, int x, int y, int ls, int rs){
if (x <= l && r <= y){
if (ls){
d[k].mx -= (pre[ls-1] - pre[d[k].ls-1]), tag[k] -= (pre[ls-1] - pre[d[k].ls-1]);
d[k].ls = tagl[k] = ls;
}
if (rs){
d[k].mx -= (pre[d[k].rs] - pre[rs]), tag[k] -= (pre[d[k].rs] - pre[rs]);
d[k].rs = tagr[k] = rs;
}
return ;
}
int mid = l + r >> 1;
pushdown(k);
if (x <= mid) modify(lc(k), l, mid, x, y, ls, rs);
if (y > mid) modify(rc(k), mid+1, r, x, y, ls, rs);
d[k] = d[lc(k)] + d[rc(k)];
}
// Returns the summary (maximum value, its leaf index and that leaf's bounds)
// over the leaves in [x, y]; k, l, r describe the current node and its range.
segment query(int k, int l, int r, int x, int y){
    const bool covered = (x <= l && r <= y);
    if (covered) return d[k];

    pushdown(k);
    const int mid = (l + r) >> 1;

    // The requested range lies entirely in one half.
    if (y <= mid) return query(lc(k), l, mid, x, y);
    if (x > mid) return query(rc(k), mid + 1, r, x, y);

    // The requested range straddles mid: merge both halves.
    const segment lpart = query(lc(k), l, mid, x, y);
    const segment rpart = query(rc(k), mid + 1, r, x, y);
    return lpart + rpart;
}
void work(int op){
int pos = 0;
for (int i=1; i<=n; i++) pre[i] = pre[i-1] + a[i].val;
for (int i=1; i<=n; i++){
if (pre[i] > pre[n] / 2){
pos = i;
break;
}
}
for (int i=1; i<=n+m; i++) mxl[i] = 1, mnr[i] = n;
for (int i=1; i<=pos; i++) mxl[a[i].x] = i + 1;
for (int i=n; i>=pos; i--) mnr[a[i].x] = i - 1;
for (int i=1; i<=m; i++) pre2[i] = pre2[i-1] + b[i].val;
build(1, 1, m);
tp1 = tp2 = 0;
for (int i=1; i<=m; i++){
while (tp1 && mxl[b[st1[tp1]].x] < mxl[b[i].x]){
modify(1, 1, m, st1[tp1-1]+1, st1[tp1], mxl[b[i].x], 0);
tp1 --;
}
while (tp2 && mnr[b[st2[tp2]].x] > mnr[b[i].x]){
modify(1, 1, m, st2[tp2-1]+1, st2[tp2], 0, mnr[b[i].x]);
tp2 --;
}
modify(1, 1, m, i, i, mxl[b[i].x], mnr[b[i].x]);
segment ret = query(1, 1, m, 1, i);
if (ret.mx + pre2[i] > ans){
ans = ret.mx + pre2[i], al = ret.ls, ar = ret.rs, bl = ret.pos, br = i;
if (op == 2) swap(al, bl), swap(ar, br);
}
st1[++tp1] = i, st2[++tp2] = i;
}
}
int main(){
scanf ("%d%d", &n, &m);
for (int i=1; i<=n; i++){
scanf ("%d", &a[i].x);
}
for (int i=1; i<=n; i++){
scanf ("%d", &a[i].val);
}
for (int i=1; i<=m; i++){
scanf ("%d", &b[i].x);
}
for (int i=1; i<=m; i++){
scanf ("%d", &b[i].val);
}
work(1);
if (pre[n] >= ans) ans = pre[n], al = 1, ar = n, bl = br = 0;
if (pre2[m] >= ans) ans = pre2[m], al = ar = 0, bl = 1, br = m;
memcpy(c, a, sizeof(a));
memcpy(a, b, sizeof(b));
memcpy(b, c, sizeof(c));
swap(n, m);
work(2);
printf ("%lld\n%d %d\n%d %d\n", ans, al, ar, bl, br);
return 0;
}