P7164 [COCI2020-2021#1] 3D Histogram (分治+线段树+凸包)

· · 题解

题面链接

这里介绍一下官方题解的做法

考虑分治,那么重点就来到了怎么计算横跨区间中点的贡献

f(i,j) 表示区间 [i,j]a 的最小值,g(i,j) 表示区间 [i,j]b 的最小值。

那么一个区间 [l,r] 的贡献就是 f(l,r) \times g(l,r) \times (r-l+1)

设分治区间中点为 mid,接下来对 f(i,j)g(i,j) 的来源做讨论:

  1. 两者都来自 mid 左边
  2. 两者都来自 mid 右边
  3. f 的最小值来自左边,g 的最小值来自右边
  4. g 的最小值来自左边,f 的最小值来自右边

其中情况 1 和情况 2 很好处理,双指针扫一遍即可。对于每一个确定的 l 找到一个 r 使得 f(mid+1,r)g(mid+1,r) 恰好大于 f(l,mid)g(l,mid)。因为随着 l 的左移 f(l,mid)g(l,mid) 都是单调减少的所以 r 单调往右能走多远走多远即可。

难点在于情况3、4的处理,这里以情况3为例:

拆一下铈子:

f(l,mid) \times g(mid+1,r) \times (r-l+1) = -l \times f(l,mid) \times g(mid+1,r) + (1+r) \times f(l,mid) \times g(mid+1,r) = ((l \times f(l,mid)) , f(l,mid)) \cdot (-g(mid+1,r) , (1+r) \times (g(mid+1,r))

因为对于每一个确定的 l ,r 的取值范围是固定的(方法和上面情况1、2类似),所以我们要快速求出在区间内与 l 的那个向量作点积的值最大的 r 的那个向量。

然后我们发现了这道题

算一下复杂度:分治+线段树维护凸包+三分,O(N\log^3N),肯定是过不去的。

然后我们发现,l \times f(l,mid)f(l,mid) 都是随着 l 单调递减的,所以说 l 对应的向量在不断地逆时针转动,所以线段树维护的每个区间内的最优解也应该是单调地逆时针转动(思考一下为什么)。对于每次分治求解,我们对线段树的每个节点记录一下上次查询得到的最优值的位置,单调地向一个方向移动,这样我们就省去了三分的复杂度,O(N\log^2N)成功 AC。

本人人傻常数大,开了 O2 才勉强过了的样子,代码十分不建议参考

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls (x << 1)
#define rs ((x << 1) | 1)
#define int long long 
using namespace std;
const int MAXN = 2e6 + 50;
const int MAXM = 25;
const int INF = 0x3f3f3f3f3f3f3f3f;
int w[MAXN], h[MAXN], minw[MAXN][MAXM], minh[MAXN][MAXM];
int idx[MAXN], lg[MAXN];
int N;
struct node
{
    int x, y;
    node(){}
    node(int a, int b)
    {
        x = a;
        y = b;
    }
    bool operator<(const node &tmp) const
    {
        return (x == tmp.x) ? y < tmp.y : x < tmp.x;
    }
    node operator+(const node &tmp) const
    {
        return node(x + tmp.x, y + tmp.y);
    }
    node operator-(const node &tmp) const
    {
        return node(x - tmp.x, y - tmp.y);
    }
    int operator*(const node &tmp) const
    {
        return x * tmp.y - y * tmp.x;
    }
} s[MAXN * 10], po[MAXN], q[MAXN * 2], p[MAXN];
int tot, ans;
struct SegT
{
    int lt[MAXN * 4], rt[MAXN * 4], pos[MAXN * 4];
    void build(int x, int l, int r)
    {
        int cnt = 0;
        for (int i = l; i <= r; ++i)
            p[++cnt] = po[i];
        // sort(p + 1, p + cnt + 1);
        int top = 0;
        for (int i = 1; i <= cnt; ++i)
        {
            while (top > 1 && (q[top] - q[top - 1]) * (p[i] - q[top]) >= 0)
                top--;
            q[++top] = p[i];
        }
        lt[x] = tot + 1;
        for (int i = 1; i <= top; ++i)
        {
            s[++tot] = q[i];
            q[i] = node(0, 0);
        }
        rt[x] = tot;
        pos[x] = rt[x];
    }
    void update(int x, int l, int r, int t)
    {
        if (r == t)
            build(x, l, r);
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        if (t <= mid)
            update(ls, l, mid, t);
        else
            update(rs, mid + 1, r, t);
    }
    bool check(int x, node t)
    {
        return (s[x - 1].x * t.x + s[x - 1].y * t.y >= s[x].x * t.x + s[x].y * t.y);
    }
    int calc(int x, node t)
    {
        int now = pos[x];
        while (check(now, t) && now > lt[x])
            now--;
        pos[x] = now;
        return s[now].x * t.x + s[now].y * t.y;
    }
    int query(int x, int l, int r, int tl, int tr, node t)
    {
        if (tl > r || tr < l)
            return -INF;
        if (tl <= l && tr >= r)
            return calc(x, t);
        int mid = (l + r) >> 1;
        return max(query(ls, l, mid, tl, tr, t), query(rs, mid + 1, r, tl, tr, t));
    }
} T;
void rmqpre()
{
    for (int i = 1; i <= N; ++i)
    {
        minw[i][0] = w[i];
        minh[i][0] = h[i];
    }
    for (int j = 1; j <= 20; ++j)
        for (int i = 1; i <= N; ++i)
        {
            minw[i][j] = min(minw[i][j - 1], minw[i + (1 << (j - 1))][j - 1]);
            minh[i][j] = min(minh[i][j - 1], minh[i + (1 << (j - 1))][j - 1]);
        }
}
inline int geth(int l, int r)
{
    int len = r - l + 1;
    int k = lg[len];
    return min(minh[l][k], minh[r - (1 << k) + 1][k]);
}
inline int getw(int l, int r)
{
    int len = r - l + 1;
    int k = lg[len];
    return min(minw[l][k], minw[r - (1 << k) + 1][k]);
}
void calc1(int l, int r)
{
    int mid = (l + r) >> 1;
    int lt = mid;
    int rt = mid;
    while (lt >= l)
    {
        while (getw(mid + 1, rt + 1) >= getw(lt, mid) && geth(mid + 1, rt + 1) >= geth(lt, mid) && rt < r)
            rt++;
        if (rt > mid)
            ans = max(ans, (rt - lt + 1) * getw(lt, mid) * geth(lt, mid));
        lt--;
    }
}
void calc2(int l, int r)
{
    int mid = (l + r) >> 1;
    int lt = mid + 1;
    int rt = mid + 1;
    while(rt <= r)
    {
        while (getw(lt - 1, mid) >= getw(mid + 1, rt) && geth(lt - 1, mid) >= geth(mid + 1, rt) && lt > l)
            lt--;
        if (lt < mid + 1)
            ans = max(ans, (rt - lt + 1) * getw(mid + 1, rt) * geth(mid + 1, rt));
        rt++;
    }
}
void clear(int l, int r)
{
    tot = 0;
}
void calc3(int l, int r)//w->r,h->l
{
    int mid = (l + r) >> 1;
    int tt = 0;
    clear(l, r);
    for (int i = mid + 1; i <= r; ++i)
    {
        po[++tt] = node(-getw(mid + 1, i), getw(mid + 1, i) * (1 + i));
        T.update(1, 1, N, tt);
    }
    int lt = mid;
    int rt = mid;
    for (int i = mid; i >= l; --i)
    {
        while (getw(mid + 1, lt + 1) >= getw(i, mid) && lt < r)
            lt++;
        while (geth(mid + 1, rt + 1) >= geth(i, mid) && rt < r)
            rt++;
        if (rt > mid && rt > lt)
        {
            ans = max(ans, T.query(1, 1, N, lt - mid + 1, rt - mid, node(i * geth(i, mid), geth(i, mid))));
            // cout << mid << " " << i << ":" << lt << " " << rt << ":" << ans << " " << endl;
        }
    }
}
void calc4(int l, int r)//h->r,w->l
{
    int mid = (l + r) >> 1;
    int tt = 0;
    clear(l, r);
    for (int i = mid + 1; i <= r; ++i)
    {
        po[++tt] = node(-geth(mid + 1, i), geth(mid + 1, i) * (1 + i));
        T.update(1, 1, N, tt);
    }
    int lt = mid;
    int rt = mid;
    for (int i = mid; i >= l; --i)
    {
        while (geth(mid + 1, lt + 1) >= geth(i, mid) && lt < r)
            lt++;
        while (getw(mid + 1, rt + 1) >= getw(i, mid) && rt < r)
            rt++;
        if (rt > mid && rt > lt)
        {
            ans = max(ans, T.query(1, 1, N, lt - mid + 1, rt - mid, node(i * getw(i, mid), getw(i, mid))));
            // cout << mid << " " << i << ":" << lt << " " << rt << ":" << ans << endl;
        }
    }
}
void solve(int l, int r)
{
    if (l == r)
    {
        ans = max(w[l] * h[l], ans);
        return;
    }
    int mid = (l + r) >> 1;
    solve(l, mid);
    solve(mid + 1, r);
    calc1(l, r);
    calc2(l, r);
    calc3(l, r);
    calc4(l, r);
}
signed main()
{
    // freopen("test.in", "r", stdin);
    scanf("%lld", &N);
    for (int i = 1; i <= N; ++i)
        scanf("%lld%lld", &h[i], &w[i]);
    rmqpre();
    lg[1] = 0;
    for (int i = 2; i <= N; ++i)
        lg[i] = lg[i >> 1] + 1;
    ans = 0;
    solve(1, N);
    printf("%lld\n", ans);
    return 0;
}