题解:CF1647F Madoka and Laziness

· · 题解

更好的阅读体验

把一个序列拆成两个单峰子序列,每个子序列的峰值都是该子序列的最大值,因此原序列的最大值一定是其中一个子序列的峰值。那么问题就转化为有多少个位置可能成为另一个子序列的峰值。

设原序列最大值的位置为 maxpos,我们先考虑第二个子序列的峰值在 maxpos 右边,如图。

假设我们确定了第二个序列峰值的位置 p,图中黑色折线为 maxpos 所在子序列,红色折线为 p 所在的子序列。则我们可以看到 [1,n] 的区间被分成了三部分:

我们考虑一个贪心。因为我们已经确定了 p[1, p) 的递增序列的末项越小就越有可能接上;(p, n] 的递减序列的首项越小就越有可能接上。我们可以设计一个 dp,求满足条件的数列的首项或末项最小或最大可以到多少。

具体地,我们设 f_{0, i} 表示 [1, i] 拆成两个递增序列(不妨用上图的红、黑来表示),强制 i 在黑色序列上,红色序列的末项最小是多少;f_{1, i} 表示 [i, n] 拆成两个递减序列,强制 i 在黑色上,红色序列的首相最小是多少。

这部分是好转移的,以 f_{0, i} 的转移为例,若 a_{i-1} < a_i,则说明 i-1 也可以在黑色序列上,因此可以用 f_{0, i-1} 来转移 f_{0, i};若 f_{0, i-1} < a_i,则说明 i-1 也可以在红色序列上,因此可以用 a_{i-1} 来转移 f_{0, i}f_{1, i} 的转移就是倒过来。

再考虑 g_{0, i} 表示把 [maxpos, i] 拆成递减的黑色序列和递增的红色序列,强制 i 在红色序列上,黑色序列的末项的最大值,以及 g_{1, i} 表示把 [maxpos, i] 拆成递减的黑色序列和递增的红色序列,强制 i 再黑色序列上,红色序列末项的最小值。那么我们就从前往后转移,和上面是类似的,这里就不在赘述了。

到这里,因为我们刚才钦定了 p > maxpos,因此需要把原序列 reverse 一遍再做,总复杂度就是 O(n)

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
#define N 500006
using namespace std;
inline void chkmax(int &x,int y){x=x<y?y:x;}
inline void chkmin(int &x,int y){x=x<y?x:y;}
int n,a[N],f[2][N],g[2][N],ans;
void solve()
{
    int maxn=-1e15,mxpos;
    for(int i=1;i<=n;i++)
        if(a[i]>maxn)maxn=a[i],mxpos=i;
    memset(f[0],0x3f,sizeof(f[0])),f[0][0]=-1e15;
    for(int i=1;i<=mxpos;i++)
    {
        if(a[i-1]<a[i])chkmin(f[0][i],f[0][i-1]);
        if(f[0][i-1]<a[i])chkmin(f[0][i],a[i-1]);
    }
    memset(f[1],0x3f,sizeof(f[1])),f[1][n+1]=-1e15;
    for(int i=n;i>=mxpos;i--)
    {
        if(a[i+1]<a[i])chkmin(f[1][i],f[1][i+1]);
        if(f[1][i+1]<a[i])chkmin(f[1][i],a[i+1]);
    }
    memset(g[0],-0x3f,sizeof(g[0]));
    memset(g[1],0x3f,sizeof(g[1])),g[1][mxpos]=f[0][mxpos];
    for(int i=mxpos+1;i<=n;i++)
    {
        if(a[i-1]<a[i])chkmax(g[0][i],g[0][i-1]);
        if(a[i-1]>a[i])chkmin(g[1][i],g[1][i-1]);
        if(g[1][i-1]<a[i])chkmax(g[0][i],a[i-1]);
        if(g[0][i-1]>a[i])chkmin(g[1][i],a[i-1]);
    }
    for(int i=mxpos+1;i<=n;i++)
        if(g[0][i]>f[1][i])ans++;
}
main()
{
    scanf("%lld",&n);
    for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
    solve(),reverse(a+1,a+1+n),solve();
    printf("%lld\n",ans);
    return 0;
}