题解 P6795 【[SNOI2020] 排列】

· · 题解

我们将答案拆成三部分,前k个数内部的连续段,后n-k个数的内部的连续段,跨越了第k个数的连续段。

第1部分是经典问题,不会的建议先做别的题比如 https://www.luogu.com.cn/problem/P4747。

之后一个结论是,考虑前k个数将值域分为的若干段,则同一段内的数在最终排列中都是相邻的。严格的证明这里没有因为我懒得想。

有了这个结论后,后n-k个数的内部的连续段数量已经可以算出来了,只有跨越的部分比较棘手。

考虑从k到1 DP计算,一个性质是,对于任意 1 \le i \le k,记 L=min_{l=i}^{k}p_l,R=max_{l=i}^{k}p_l 一定存在 k \le j \le n,使得 [i,j]构成连续段,且 R=max_{l=i}^{j}p_lL=min_{l=i}^{j}p_l,直观描述就是值域区间 [L,R] 内的空洞一定被填满了。

之后一个性质是为了便于我们DP计算而发掘的。就是如果L-1没有出现在前k个数中,则L-1所处的极长未出现在前k个数中的值域连续段一定是递减排布的;同样的R+1若没出现,则其所处的极长段一定递增排布的。

于是,对于i个位置的贡献,唯一有决策必要的就是L-1和R+1所处极长段的情况。这样我们就得到了DP状态 f_{i,s1,s2} 用以记录所有可能区间 [k+1,j] 内任意放置元素,配合区间 [i,k] 产生的最大贡献,其中 s_1,s_2 \in \{0,1,2\} 分别记录左/右极长段的情况,0表示未选中,1表示是最后一个被选中,2表示被选中且不是最后一个,这里选中定义为出现在 [k+1,j] 中。

转移的细节

具体实现可以见后文代码部分中第一个 for(i=k;i;--i) 的循环内部。

方便起见,我们只考虑 L,R 发生变化的位置 i。记上一个位置为ii。

一个性质是 [转移前后左极长段相同]+[转移前后右极长段相同]=1,因为ii到i L,R 恰好有一个发生变化。

剔除非法状态与转移

首先要剔除一些非法状态。根据定义 s_1,s_2 不能同时为1,方便起见我们强制要求 s_1 \neq 0 时左极长段必须非空,s_2 也是。这部分可以见代码第55-57行。

还要剔除一些非法转移。

如果转移前后左极长段相同,并且在ii的状态中 s_1=2(被其它东西盖住了),那么在i的状态中,s_1 就不能为1。这部分可以见代码第58行。

如果转移前后左极长段相同,并且在ii的状态中 s_1 \neq 0(已选中),那么在i的状态中,s_1 就不能为0。这部分可以见代码第61行。

如果转移前后左极长段相同,并且在转移前后 s1 均为1,就要好好讨论了。此时转移合法当且仅当左极长段可以直接继承而无需在中间插入其它东西。如果i的状态的 s_2=2,则右极长连续段一定是新盖住的,这是一种非法情况。另一种非法情况是,R变大时可能会导致新的值域空洞产生,如果需要填坑,那么也是一种非法情况。我们维护一个前缀和数组 ss2 用于维护一段区间中坑的数量,然后判一下在ii的状态中已经填上的坑的数量——这等于ii的状态中 [s_2=2]。这部分可以见代码第59-60行。

如果右极长段相同,对称地讨论即可。

计算转移的权值。

首先要考虑 [L,R] 是否可以构成连续段,亦即 [L,R] 中没有洞,且没有左右极长段产生干扰。

如果转移前后左极长段相同,则只需检查ii的状态中 s_1=2 ,或者 s_1=1 且在转移到i的过程中被盖上了——这在前面判“转移前后 s1 均为1”的时候已经讨论过了。这部分可以见代码第69-72行。

右极长段同理。

之后要考虑的是左右极长段的贡献,前面已经给了那么多讨论的示范了,这里我就咕咕咕了就留给读者自行思考了,需要提醒的是贡献值可能等于极长段长度,也可能等于1,需要仔细分析。这部分可以见代码第78-80行。

最后构方案是简单的,记下每个状态的前驱即可,往前走的时候看下左右极长段怎么放即可。

下面是loj上的AC代码,它在洛谷上还在judging。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll C2(int x){return 1ll*x*(x+1)>>1;}
const int N=2e5+5;
int n,k,i,ii,j,l,p[N],ip[N],mn[N],mx[N],ss[N],ss2[N],prr[N],sf[N],st1[N],st2[N],tp,pr[N];
ll C1,f[N][9];int pre[N][9];
namespace N1{
struct info{int v,c;};
inline info uni(const info&a,const info&b){
    static info c;c.v=max(a.v,b.v);
    c.c=(a.v==c.v?a.c:0)+(b.v==c.v?b.c:0);return c;
}
const info E=info{-(1<<25),0};
struct node{
    int ad,lb,rb,md;info s;
}t[N*3];
inline void upd(int i){t[i].s=uni(t[i<<1].s,t[i<<1|1].s);t[i].s.v+=t[i].ad;}
void build(int i,int l,int r){
    t[i]=node{0,l,r,l+r>>1,0,r-l+1};if(l==r)return;
    build(i<<1,l,t[i].md);build(i<<1|1,t[i].md+1,r);
}
void mdy(int i,int l,int r,int v){
    if(l<=t[i].lb && t[i].rb<=r){t[i].ad+=v;t[i].s.v+=v;return;}
    if(l<=t[i].md)mdy(i<<1,l,r,v);if(t[i].md<r)mdy(i<<1|1,l,r,v);upd(i);
}
info ask(int i,int l,int r){
    if(l<=t[i].lb && t[i].rb<=r)return t[i].s;
    info ret=uni(l<=t[i].md?ask(i<<1,l,r):E,t[i].md<r?ask(i<<1|1,l,r):E);
    ret.v+=t[i].ad;return ret;
}
inline void main(){
    build(1,1,n);
    for(i=1;i<=k;++i){
        mdy(1,1,i,-1);
        if(ip[p[i]-1])mdy(1,1,ip[p[i]-1],1);
        if(ip[p[i]+1])mdy(1,1,ip[p[i]+1],1);
        info u=ask(1,1,i);C1+=u.v==-1?u.c:0;
        ip[p[i]]=i;
    }
}
}
int main(){
    scanf("%d%d",&n,&k);for(i=1;i<=k;++i)scanf("%d",p+i),ss[p[i]]=1;
    for(i=1;i<=n;++i)ss[i]+=ss[i-1];
    mn[k]=mx[k]=p[k];for(i=k-1;i;--i)mn[i]=min(mn[i+1],p[i]),mx[i]=max(mx[i+1],p[i]);
    N1::main();
    for(i=1;i<=n;++i)pr[i]=ip[i]?0:pr[i-1]+1;for(i=n;i;--i)sf[i]=ip[i]?0:sf[i+1]+1;
    for(i=1;i<=n;++i)if(!ip[i] && (i==1 || ip[i-1]))C1+=C2(sf[i]),ss2[i]=1;
    for(i=1;i<=n;++i)ss2[i]+=ss2[i-1];
    for(i=k;i;--i)if(mn[i]!=mn[i-1] || mx[i]!=mx[i-1]){
        for(ii=i;mn[i]==mn[ii] && mx[i]==mx[ii];++ii);prr[i]=ii;
        for(j=0;j<9;++j)for(l=0;l<9;++l){
            int Lj=j%3,Rj=j/3,Ll=l%3,Rl=l/3;
            if(Lj==1 && Rj==1)continue;if(Ll==1 && Rl==1)continue;
            if(Lj && !pr[mn[i]-1])continue;if(Rj && !sf[mx[i]+1])continue;
            if(Ll && !pr[mn[ii]-1])continue;if(Rl && !sf[mx[ii]+1])continue;
            if(mn[i]==mn[ii] && Lj==1 && Ll==2)continue;
            if(mn[i]==mn[ii] && Ll==1 && Lj==1 && (Rj==2
                 || ss2[mx[i]-1]-ss2[mx[ii]]>(Rl==0 || !sf[mx[ii]+1]?0:1)))continue;
            if(mn[i]==mn[ii] && !Lj && Ll)continue;
            if(mx[i]==mx[ii] && Rj==1 && Rl==2)continue;
            if(mx[i]==mx[ii] && Rl==1 && Rj==1 && (Lj==2 
                 || ss2[mn[ii]-1]-ss2[mn[i]]>(Ll==0 || !pr[mn[ii]-1]?0:1)))continue;
            if(mx[i]==mx[ii] && !Rj && Rl)continue;
            ll nv=f[ii][l];
            if(ss[mx[i]]-ss[mn[i]-1]==k-i+1){
                int zz=mx[i]-mn[i]>k-i;
                if(mn[i]==mn[ii]){
                    if(Ll==2)zz=0;
                    if(Ll==1 && ss2[mx[i]-1]-ss2[mx[ii]]>(Rl==2?1:0))zz=0;
                }
                if(mx[i]==mx[ii]){
                    if(Rl==2)zz=0;
                    if(Rl==1 && ss2[mn[ii]-1]-ss2[mn[i]]>(Ll==2?1:0))zz=0;
                }
                nv+=zz;
                if(Lj==1 || (Lj==2 && (mn[i]==mn[ii]?Ll==0
                        || (Ll==1 && ss2[mx[i]-1]-ss2[mx[ii]]<=(Rl==2?1:0)):1)))nv+=pr[mn[i]-1];
                    else if(Lj)nv++;
                if(Rj==1 || (Rj==2 && (mx[i]==mx[ii]?Rl==0
                        || (Rl==1 && ss2[mn[ii]-1]-ss2[mn[i]]<=(Ll==2?1:0)):1)))nv+=sf[mx[i]+1];
                    else if(Rj)nv++;
            }
            if(nv>f[i][j])f[i][j]=nv,pre[i][j]=l;
        }
    }
    ll mxx=-N;int id;for(i=0;i<9;++i)if(f[1][i]>mxx)mxx=f[1][i],id=i;
    printf("%lld\n",mxx+C1);
    for(i=1;i<=k;)st1[++tp]=i,st2[tp]=id,id=pre[i][id],i=prr[i];
    for(i=k;tp;--tp){
        int x=st1[tp],y=st1[tp+1];
        if(y){
            if(mn[x]<mn[y]){for(j=mn[y]-1;j>mn[x];--j)if(!ip[j])p[++i]=j;}
                else for(j=mx[y]+1;j<mx[x];++j)if(!ip[j])p[++i]=j;
        }
        auto pushR=[&](){
            if(st2[tp]/3 && (mx[x]!=mx[y] || !(st2[tp+1]/3)))
                for(j=mx[x]+1;j<=n && !ip[j];++j)p[++i]=j,ip[j]=-1;
        };
        auto pushL=[&](){
            if(st2[tp]%3 && (mn[x]!=mn[y] || !(st2[tp+1]%3)))
                for(j=mn[x]-1;j && !ip[j];--j)p[++i]=j,ip[j]=-1;
        };
        if(st2[tp]/3==1)pushL(),pushR();
            else pushR(),pushL();
    }
    for(i=1;i<=n;++i)printf("%d%c",p[i],i==n?'\n':' ');
}