题解 CF817D 【Imbalanced Array】

· · 题解

题意简述

给你一个长度为 n 序列 a_i,让你求出

\sum\limits_{1\leqslant i\leqslant j\leqslant n,i\neq j} {\max\limits_{i\leqslant k\leqslant j}\{a_k\}}-{\min\limits_{i\leqslant k\leqslant j}\{a_k\}}

转化题意

如果直接暴力统计,时间复杂度为 O(n^3),显然无法承受。

所以我们需要对计算方法进行优化。

我们可以考虑对于每个 a_i 它可以对多少个区间进行贡献。

这里以最大值为例。

显然,对于每一个 a_i,它的左端点可选区间为此图的 [l,i],右端点可选区间为此图的 [i,r]

我们可以看到,对于节点 a_i,我们可以发现:

其左端点的极左点为在该数左边且从后往前数第一个大于 a_i 的数的位置,我们将其记录为 l_i,特别地,如果没有比它大的,那么 l_i0

其右端点的极有点为在该数右边且从前往后数第一个大于 a_i 的数的位置,我们将其记录为 r_i,特别地,如果没有比它大的,那么 r_in+1

而且我们可以看到,对于每一个数,由于左右端点每个可以各选一个。由乘法原理可得每个 a_i 对于答案的贡献为

a_i*(i-l_i)*(r_i-i)

最后求出每一个数对答案的贡献并求和。

单调栈优化

前面讲到了如何利用左右端点求出所有的贡献,但暴力求 l_ir_i 的时间复杂度为 O(n^2) 对于 10^6 的数据范围依旧远远不够。

所以我们要祭出一个好东西:单调栈

单调栈,可以简单理解为不会从队头出只会从队尾出的单调队列,但只能从队尾取出

也就是说我们可以省略掉 head(后面为了适应栈用 head 代替了 tail

l_i 为例,我们可以写出递推式:

l_i = \max\limits_{1\leqslant j\leqslant i,a_j>a_i}\{j\}

但由于要取 \max 所以我们必须在尾端取,这时候便可以用单调栈维护。

详细过程:

最开始将 0 入栈;

对于每次查询,如果 a_i>a_{stack_{head}},那么 l_i=i-1

否则从将栈中的元素一一弹出,直到 a_i>a_{stack_{head}}

l_i 赋值为 stack_{head},并将 i 入栈。

一些提示

Code

#include<bits/stdc++.h>
using namespace std;
int sta[1000002],a[1000002],n,mxl[1000001],mxr[1000001],mil[1000001],mir[1000001],h;
long long ans;
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    sta[1]=0;h=1;
    for(int i=1;i<=n;i++)
    {
        if(a[i]>a[sta[h]])sta[++h]=i,mil[i]=i-1;
        else
        {
            while(a[i]<a[sta[h]])h--;
            mil[i]=sta[h];
            sta[++h]=i;
        }
    }
    sta[1]=n+1;h=1;
    for(int i=n;i>=1;i--)
    {
        if(a[i]>a[sta[h]])sta[++h]=i,mir[i]=i+1;
        else
        {
            while(a[i]<=a[sta[h]])h--;
            mir[i]=sta[h];
            sta[++h]=i;
        }
    }
    a[0]=a[n+1]=0x7fffffff;
    sta[1]=0;h=1;
    for(int i=1;i<=n;i++)
    {
        if(a[i]<a[sta[h]])sta[++h]=i,mxl[i]=i-1;
        else
        {
            while(a[i]>a[sta[h]])h--;
            mxl[i]=sta[h];
            sta[++h]=i;
        }
    }
    sta[1]=n+1;h=1;
    for(int i=n;i>=1;i--)
    {
        if(a[i]<a[sta[h]])sta[++h]=i,mxr[i]=i+1;
        else
        {
            while(a[i]>=a[sta[h]])h--;
            mxr[i]=sta[h];
            sta[++h]=i;
        }
    }
    //for(int i=1;i<=n;i++)
        //cout<<mil[i]<<' '<<mir[i]<<' '<<mxl[i]<<' '<<mxr[i]<<endl; 
    for(int i=1;i<=n;i++)
        ans+=1ll*a[i]*(1ll*(i-mxl[i])*(mxr[i]-i)-1ll*(i-mil[i])*(mir[i]-i));
    cout<<ans;
}