CF2041J Bottle Arrangement

· · 题解

首先考虑暴力 dp,我们按照 b_i 从大到小插入到序列中。

也就是设 f_{l,r} 表示将 b 的前 r-l+1 大放到区间 [l,r] 中的答案。那么转移就是枚举当前的 b_{r-l+1} 放到 l 还是 r,判断一下代价就可以了。

时间复杂度 \mathcal O(n^2)

于是考虑优化。

发现我们从大往小扫 b_i 的过程当中,能放 b_i 的位置是不断增多的。下面设当前扫到的 ib_i = x

也就是说我们可以把序列 a 每个元素看成 011 则表示当前 a_j \gt x,反之就是 0

那么如果一个区间 [l,r] 包含 0,那么 f_{l,r} = + \infty,是没有用的。

所以我们只需要去考虑那些 1 构成的连续段,而且连续段越长越好,因为这些为 1 的位置对于之后更小的 b_i 都是本质相同的了,所以我们并不关心他们究竟是什么。

现在的思路就是去维护每个 1 的连续段的答案, 每个答案表示把 \ge xb_i 放到这个连续段中的一个区间的代价最小值。

设现在扫到了第 i 个,如果一个连续段长度 \lt i-1,那么就一定不合法了。

因为 bi 个数根本放不进去,并且它在被合并到一个合法的连续段之前就永远不合法了(答案为 +\infty)。

反之如果 len \ge i,那么之前是合法的,现在也是合法的。

而对于 len = i-1 的情况就要特殊一些,因为如果这个连续段 [l,r] 旁边 a_{l-1}=x 或者 a_{r+1}=x,那么我们把 x 放到 l-1/r+1 是可以花费 1 的代价让连续段继续合法;反之就是不合法的。

那么每一次我们用一个 set 维护当前合法的连续段的左端点和长度就可以了,对于那些实际长度 i-1 但是可以花费 1 的代价变得合法的,我们就把它看成长度为 i 就可以了。

然后每次处理完删去不合法的连续段,其实就是把长度 \lt i 的删了。

每次扫从 i \to i+1 时,我们要做的只是让有些位置变成 1

而这些位置变成 1 显然可以合并一些连续段,而合并两个连续段实际上就是答案取 \min

这个东西可以用并查集维护。

这样就做完了,实现中注意第一个位置有一些细节,时间复杂度 \mathcal O(n \log n)。代码。

#include <bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define fi first
#define se second

const int N=1e6+5,inf=1e9;
int n,a[N],b[N],fa[N],ans[N],sz[N],len[N],id[N];
bool vis[N];
set<pii> S;

int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}

void ckmn(int &a,int b){a=min(a,b);}

void merge(int u,int v){
  u=find(u),v=find(v);
  if(u==v) return;
  if(ans[u]!=inf) S.erase({sz[u]+len[u],u});
  if(ans[v]!=inf) S.erase({sz[v]+len[v],v});
  fa[v]=u,sz[u]+=sz[v],len[u]=len[v]=0;
  ans[u]=min(ans[u],ans[v]);

  if(ans[u]!=inf) S.insert({sz[u],u});
}

int main(){
  ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
  cin>>n;
  for(int i=1;i<=n;i++) cin>>a[i],ans[i]=inf,sz[i]=1,fa[i]=id[i]=i;
  for(int i=1;i<=n;i++) cin>>b[i];

  if(n==1){
    if(a[1]>b[1]) cout<<"0\n";
    else if(a[1]==b[1]) cout<<"1\n";
    else cout<<"-1\n";
    return 0;
  }

  sort(id+1,id+n+1,[&](int x,int y){return a[x]>a[y];});
  sort(b+1,b+n+1,[&](int x,int y){return x>y;});

  for(int i=1,j=1;i<=n;i++){
    while(j<=n&&a[id[j]]>b[i]){
      int pos=id[j];
      if(i==1) ans[pos]=0,S.insert({1,pos});
      else if(i==2&&a[pos]==b[1]) ans[pos]=1,S.insert({1,pos});

      if(vis[pos-1]) merge(pos-1,pos);
      if(vis[pos+1]) merge(pos,pos+1);
      vis[pos]=1;
      ++j;
    }
    int k=j;
    while(k<=n&&a[id[k]]==b[i]){
      int pos=id[k];
      if(vis[pos-1]){
        int l=find(pos-1);
        if(ans[l]!=inf&&sz[l]==i-1&&!len[l]){
          S.erase({sz[l],l});
          S.insert({sz[l]+1,l}),++len[l],++ans[l];
        }
      }
      if(vis[pos+1]){
        int l=pos+1;
        if(ans[l]!=inf&&sz[l]==i-1&&!len[l]){
          S.erase({sz[l],l});
          S.insert({sz[l]+1,l}),++len[l],++ans[l];
        }
      }
      ++k;
    }

    while(S.size()&&(*S.begin()).fi<i) ans[(*S.begin()).se]=inf,S.erase(S.begin());
  }
  if(!S.size()) cout<<"-1\n";
  else cout<<ans[(*S.begin()).se]<<'\n';
  return 0;
}