题解:AT_utpc2020_e Sort Segments

· · 题解

题目链接

题意比较简单,此处不再赘述。

本题主流有两种做法,一种是扫描线,另一种是一个神秘 dp。这篇题解只讲解码量和常数都更小的后者。

首先我们发现,有些位置 i 可以不被任何一次操作覆盖。为了方便统计,对于所有不被覆盖的 i,我们假设在 i 处执行了一次将 [i,i] 升序排序的操作,这样就要求所有位置都被操作了。这是第一步转化。

因为这是一个计数题,要求我们不重不漏地计算出每一种方案。于是我们需要考虑,什么样的情况会导致重复计算,以及如何去重。

对于一个序列 \{1,2,3\},发现我们无论对它如何操作,它都不会变。而对于区间 \{3,2,5,4\},发现我们无论是操作 [1,2][3,4] 还是只操作 [1,4],最终都会得到 \{2,3,4,5\}

所以不难总结出一个方案:

优先用长度更小的操作区间去代替长度更大的操作区间。

对于操作 [l,r],若存在 k\in[l,r) 使得 \max_{i=l}^ka_i<\min_{i=k+1}^ra_i,则操作 [l,r] 不合法。

因为我们可以用操作 [l,k][k+1,r] 去替代操作 [l,r]

于是考虑 dp。

dp_i 表示将 1\sim i 用操作区间覆盖完的方案数。定义 w(l,r) 表示操作区间 [l,r] 是否合法 (若合法则为 1,否则为 0)。

则转移式为:

dp_i=\sum_{j=0}^{i-1} dp_j\times w(j+1,i)

初始化 dp_0=1,答案为 dp_n

暴力做可以做到 O(n^3)

如何进行优化?

考虑将拥有相同的 w(l,i)l 一起计算。如何找到这样的 l 是我们下一步的目标。

对于 i,找到最大的 y 使得 y<ia_y<a_i,最大的 x 使得 x<ya_x>a_i;记录 z(x,i] 中最小的 a 所在位置,可知 x<z\le y。先暂不考虑 x,y 不存在的情况。

于是我们对 1\sim i 进行了分段。分别为 [1,x],(x,y](y,i]。这样便可以发现一些关于 w(l,i) 的特点:

那么对于第三种情况,我们该如何快速求出答案呢?

我们发现,因为判断一个区间合不合法依靠的是前缀最大值与后缀最小值,而 a_z 又是 [z,i] 中的最小值,于是将操作区间 [l,z] 扩展为 [l,i] 并不会影响当 k\le z 时的后缀最小值,此时只需要判断 k>z 是否会影响区间的合法性即可。

那么什么时候 k>z 会使得 w(l,z)\neq w(l,i) 呢?

发现当 l\in[x,z] 时,w(l,z)=1w(l,i)=0。于是只要把这一部分的贡献减去即可。

sum_i=\sum_{j=0}^{i-1}dp_i,于是有转移:

dp_i=(sum_i-sum_y)+[dp_z-(sum_z-sum_x)]

然后发现这个东西,只要知道了 x,y,z 后,转移是 O(1) 的!

而对于每一个 i 预处理 x,y,z 只需要 O(n\log n)。于是总时间复杂度 O(n\log n)

还需要注意一些边界情况,就如上文暂时跳过的不存在 x,y 的情况。如果 y 不存在,则说明 a_i 是前缀最小值,任意的 w(l,i) 都为 1,直接转移即可。如果 x 不存在,则就不用去找 z,直接让 dp_i=sum_i-sum_y 即可。

预处理 x 用二分或 set 可以搞定,预处理 y 可以用单调栈,预处理 z 可以用 st 表。

下面是本题代码,为了增加可读性,删去了一些常数优化。

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5,mod=998244353;
int n,a[N],p[N];//p是用来辅助求x的 
int x[N],y[N];//i对应的x,y 
int stk[N],top;//单调栈 
int st[20][N];//st表 
int sum[N],dp[N];//dp数组及其前缀和优化数组 
set<int> q;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
bool cmp(int i,int j){return a[i]>a[j];}
int query(int l,int r)
{
    int ll=log2(r-l+1);
    int p1=st[ll][l],p2=st[ll][r-(1<<ll)+1];
    return (a[p1]<a[p2])?p1:p2;
}
int main()
{
    n=read();
    for(int i=1;i<=n;i++)
    {
        a[i]=read();
        p[i]=st[0][i]=i;

        while(top>0&&a[stk[top]]>=a[i]) top--;//单调栈求出y 
        y[i]=stk[top];
        stk[++top]=i;
    }
    sort(p+1,p+n+1,cmp);
    for(int i=1;i<=n;i++)//set求出x 
    {
        q.insert(p[i]);
        if(!y[p[i]]) continue;
        auto it=q.lower_bound(y[p[i]]);
        if(it!=q.begin()) it--,x[p[i]]=(*it);
    }

    int l=log2(n);//st表求区间最小值所在位置,用于求z 
    for(int i=1;i<=l;i++)
    {
        for(int j=1;j+(1<<i)-1<=n;j++)
        {
            int p1=st[i-1][j],p2=st[i-1][j+(1<<i-1)];
            st[i][j]=(a[p1]<a[p2])?p1:p2;
        }
    }

    dp[0]=sum[1]=1;
    dp[1]=1,sum[2]=2;//初始化,直接初始化了两位 
    for(int i=2;i<=n;i++)
    {
        dp[i]=(sum[i]-sum[y[i]]+mod)%mod;//不用判y[i]是否存在,不存在时直接取 y[i]=0 也是对的 
        if(x[i])
        {
            int z=query(x[i]+1,i);
            dp[i]=((1ll*dp[i]+dp[z]-(sum[z]-sum[x[i]]))%mod+mod)%mod;
        }
        sum[i+1]=(sum[i]+dp[i])%mod;
    }
    printf("%d\n",dp[n]);
    return 0;
}