题解 CF817D 【Imbalanced Array】
Velix
2020-08-11 21:29:37
### 题意简述
给你一个长度为 $n$ 序列 $a_i$,让你求出
$$\sum\limits_{1\leqslant i\leqslant j\leqslant n,i\neq j} {\max\limits_{i\leqslant k\leqslant j}\{a_k\}}-{\min\limits_{i\leqslant k\leqslant j}\{a_k\}}$$
---
### 转化题意
如果直接暴力统计,时间复杂度为 $O(n^3)$,显然无法承受。
所以我们需要对计算方法进行优化。
我们可以考虑对于每个 $a_i$ 它可以对多少个区间进行贡献。
这里以最大值为例。
显然,对于每一个 $a_i$,它的左端点可选区间为此图的 $[l,i]$,右端点可选区间为此图的 $[i,r]$。
![](https://cdn.luogu.com.cn/upload/image_hosting/q7imq639.png)
我们可以看到,对于节点 $a_i$,我们可以发现:
其左端点的极左点为**在该数左边且从后往前数第一个大于** $a_i$ **的数的位置**,我们将其记录为 $l_i$,特别地,如果没有比它大的,那么 $l_i$ 为 $0$;
其右端点的极有点为**在该数右边且从前往后数第一个大于** $a_i$ **的数的位置**,我们将其记录为 $r_i$,特别地,如果没有比它大的,那么 $r_i$ 为 $n+1$。
而且我们可以看到,对于每一个数,由于左右端点每个可以各选一个。由乘法原理可得每个 $a_i$ 对于答案的贡献为
$$a_i*(i-l_i)*(r_i-i)$$
最后求出每一个数对答案的贡献并求和。
---
### 单调栈优化
前面讲到了如何利用左右端点求出所有的贡献,但暴力求 $l_i$ 和 $r_i$ 的时间复杂度为 $O(n^2)$ 对于 $10^6$ 的数据范围依旧远远不够。
所以我们要祭出一个好东西:**单调栈**
单调栈,可以简单理解为不会从队头出只会从队尾出的单调队列,**但只能从队尾取出**。
也就是说我们可以省略掉 $head$(后面为了适应栈用 $head$ 代替了 $tail$)
以 $l_i$ 为例,我们可以写出递推式:
$$l_i = \max\limits_{1\leqslant j\leqslant i,a_j>a_i}\{j\}$$
但由于要取 $\max$ 所以我们必须在尾端取,这时候便可以用单调栈维护。
**详细过程:**
最开始将 $0$ 入栈;
对于每次查询,如果 $a_i>a_{stack_{head}}$,那么 $l_i=i-1$;
否则从将栈中的元素一一弹出,直到 $a_i>a_{stack_{head}}$;
将 $l_i$ 赋值为 $stack_{head}$,并将 $i$ 入栈。
---
### 一些提示
- 答案会爆 $longlong$
- **注意,求** $r_i$ **时栈顶必须大于等于** $r_i$ **而不是大于**
- 记得求 $\max$ 时将 $a_0$ 和 $a_{n+1}$ 赋值为无穷大
---
### $Code$
```cpp
#include<bits/stdc++.h>
using namespace std;
int sta[1000002],a[1000002],n,mxl[1000001],mxr[1000001],mil[1000001],mir[1000001],h;
long long ans;
int main()
{
cin>>n;
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
sta[1]=0;h=1;
for(int i=1;i<=n;i++)
{
if(a[i]>a[sta[h]])sta[++h]=i,mil[i]=i-1;
else
{
while(a[i]<a[sta[h]])h--;
mil[i]=sta[h];
sta[++h]=i;
}
}
sta[1]=n+1;h=1;
for(int i=n;i>=1;i--)
{
if(a[i]>a[sta[h]])sta[++h]=i,mir[i]=i+1;
else
{
while(a[i]<=a[sta[h]])h--;
mir[i]=sta[h];
sta[++h]=i;
}
}
a[0]=a[n+1]=0x7fffffff;
sta[1]=0;h=1;
for(int i=1;i<=n;i++)
{
if(a[i]<a[sta[h]])sta[++h]=i,mxl[i]=i-1;
else
{
while(a[i]>a[sta[h]])h--;
mxl[i]=sta[h];
sta[++h]=i;
}
}
sta[1]=n+1;h=1;
for(int i=n;i>=1;i--)
{
if(a[i]<a[sta[h]])sta[++h]=i,mxr[i]=i+1;
else
{
while(a[i]>=a[sta[h]])h--;
mxr[i]=sta[h];
sta[++h]=i;
}
}
//for(int i=1;i<=n;i++)
//cout<<mil[i]<<' '<<mir[i]<<' '<<mxl[i]<<' '<<mxr[i]<<endl;
for(int i=1;i<=n;i++)
ans+=1ll*a[i]*(1ll*(i-mxl[i])*(mxr[i]-i)-1ll*(i-mil[i])*(mir[i]-i));
cout<<ans;
}
```