CF1601C Optimal Insertion

· · 题解

给出一个长度为 n 的序列 a 和长度为 m 的序列 b .

你可以把 b 以任意的顺序插入 a 的任意位置 ,得到一个长度为 n+m 的序列 c .

c 最少有多少个逆序对 .

多测,有 t 组数据 .

1\leq t\leq 10^4,1\leq n,m\leq 10^6,1\leq a_i\leq 10^9,1\leq b_i\leq 10^9,\sum n\leq 10^6,\sum m\leq 10^6

首先考虑 b 的插入顺序 . 猜测 b 是按照从小到大顺序插入的 .

考虑对于当前插入 , 存在位置 b_i>b_{i+1} ,插入的位置是 p_1p_2考 ,虑交换两个数 .

交换前的 a 产生的代价是区间 [0,p_1-1]a_j>b_i 数的个数 + 区间 [p_2+1,n-1]a_j<b_i 数的个数 + 区间 [0,p_1-1]a_j>b_i 的个数 + 区间 [p_2+1,n-1]a_j<b_i 数的个数 + [p_1,p_2]a_j<b_i

数的个数 + [p_1,p_2]a_j>b_{i+1} 数的个数 .

交换后的 a 产生的代价是区间 [0,p_1-1]a_j>b_i 数的个数 + 区间 [p_2+1,n-1]a_j<b_i 数的个数 + 区间 [0,p_1-1]a_j>b_i 的个数 + 区间 [p_2+1,n-1]a_j<b_i 数的个数 + [p_1,p_2]a_j<b_{i+1}

数的个数 + [p_1,p_2]a_j>b_{i} 数的个数 .

前面很长一段都是相同的 ,对于后面的部分 ,因为 b_i>b_{i+1} ,所以必定有 :

所以,交换之后不会更劣 .

因此,从小到大插入是最优的 .

b 排完序之后是 b_0b_{n-1} . 对于 b_i 来说,b_{i-1} 最优插入位置集合中最左端的位置必定小于等于 b_i 最优插入位置集合中最左端的位置 .

考虑证明, b_{i-1} 最优插入位置集合中最左端的位置为 p_1b_i 最优插入位置集合中最左端的位置为 p_2 .

那么,p_1a 产生的代价是 [0,p_1-1]a_j>b_{i-1} + [p_1,n-1]a_j<b_{i-1} .

b_i 放到 p_1 上面,[0,p_1] 产生的代价必然是不增的,[p_1,n-1] 中产生的代价必然是是不降的 .

考虑对于 b_i ,如果向左移动决策点 ,原本在 [b_{i-1},b_i-1] 的数 1\to 1 ,代价不变,在 [0,b_{i-1}] 的数会产生 0\to 1 的代价 ,[b_{i}+1,\infty) 的数产生的代价由 1\to 0 .

考虑对于 b_{i-1} ,如果向左移动决策点,原本在 [b_{i-1},b_i-1] 的数 1\to 0 ,代价减 1 ,在 [0,b_{i-1}] 的数会产生 0 \to 1 的代价,[b_i+1,\infty] 的数产生的代价是 1\to 0 .

唯一的区别在于 [b_i,b_i-1] 的点,b_i 不变化 b_{i-1} 是减一,考虑用反证法简单证明一下, p_1\leq p_2 必然成立 .

因此,p_0\leq p_1\leq \cdots \leq p_{n-1} 必然成立 . 也就是每个 b_i 都能放到最优位置上 .

接下来有两种不同的方法处理 .

sol1

建立一棵线段树,每个位置维护插入位置 i 新增加的逆序对数 .

新增加的逆序对数由两部分组成,前面比 x 大的 + 后面比 x 小的 .

从小到大依次插入 b ,那么 x 也从小变到大 .

x 变到 y 的时候 , x<y , 考虑对于线段树的影响 .

对于 a 中 ,在区间 [x+1,y] 的值对后面位置的影响会消失,操作为区间减一 .

对于 a 中,在区间 [x,y-1] 的值对前面位置会开始有影响,操作为区间加一 .

因为证明了最有位置最小值递增,所以每个 b_i 都能取到最小值 .

所以,只需要区间修改,全局最大值即可 .

时间复杂度 : O(n\log n)

空间复杂度 : O(n)

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    int res=0;
    while(ch>='0'&&ch<='9'){
        res=(res<<3)+(res<<1)+ch-'0';
        ch=getchar();
    }
    return res;
}
inline void print(long long res){
    if(res==0){
        putchar('0');
        return;
    }
    int a[20],len=0;
    while(res>0){
        a[len++]=res%10;
        res/=10;
    }
    for(int i=len-1;i>=0;i--)
        putchar(a[i]+'0');
}
int t;
int n,m;
int a[1000010],b[1000010];
class node{public:int tag,val;}ts[1000010<<2];
int P[1000010];
int bit[1000010],c[1000010],cnt=0;
inline void pd(int x){  
    if(ts[x].tag==0)return;
    ts[x<<1].val+=ts[x].tag;
    ts[x<<1].tag+=ts[x].tag;
    ts[x<<1|1].val+=ts[x].tag;
    ts[x<<1|1].tag+=ts[x].tag;
    ts[x].tag=0;
}
void build(int x,int l,int r){
    if(l==r){
        ts[x].val=ts[x].tag=0;
        return;
    }
    int mid=(l+r)>>1;
    build(x<<1,l,mid);
    build(x<<1|1,mid+1,r);
    ts[x].val=ts[x].tag=0;
}
void upd(int x,int l,int r,int cl,int cr,int val){
    if(l==cl&&r==cr){
        if(l!=r)pd(x);
        ts[x].val+=val;
        ts[x].tag+=val;
        return;
    }
    pd(x);
    int mid=(l+r)>>1;
    if(cr<=mid)upd(x<<1,l,mid,cl,cr,val);
    else if(mid+1<=cl)upd(x<<1|1,mid+1,r,cl,cr,val);
    else{
        upd(x<<1,l,mid,cl,mid,val);
        upd(x<<1|1,mid+1,r,mid+1,cr,val);
    }
    ts[x].val=min(ts[x<<1].val,ts[x<<1|1].val);
}
int qry(int x,int l,int r,int pos){
    if(l==r)return ts[x].val;
    pd(x);
    int mid=(l+r)>>1;
    if(pos<=mid)return qry(x<<1,l,mid,pos);
    else return qry(x<<1|1,mid+1,r,pos);
}
int upd(int i){
    while(i){
        bit[i]++;
        i-=i&-i;
    }
}
int qry(int i){
    int res=0;
    while(i<=n){
        res+=bit[i];
        i+=i&-i;
    }
    return res;
}
void init(){
    for(int i=0;i<=n;i++)bit[i]=0;
}
void solve(){
    init();
    n=read();m=read();
    for(int i=0;i<n;i++)a[i]=read();
    for(int i=0;i<m;i++)b[i]=read();
    sort(b,b+m);
    vector<pair<int,int> >v;
    for(int i=0;i<n;i++)v.push_back(make_pair(a[i],i));
    sort(v.begin(),v.end());
    build(1,0,n);
    for(int i=0;i<(int)v.size();i++){
        int pos=v[i].second;
        upd(1,0,n,pos+1,n,1);
    }
    int p1=0,p2=0;
    while(p1<(int)v.size()&&v[p1].first<b[0]){
        upd(1,0,n,0,v[p1].second,1);
        p1++;
    } 
    while(p2<(int)v.size()&&v[p2].first<=b[0]){
        upd(1,0,n,v[p2].second+1,n,-1);
        p2++;
    }
    long long ans=0;
    ans+=ts[1].val;
    for(int i=1;i<m;i++){
        while(p1<(int)v.size()&&v[p1].first<=b[i]-1){
            upd(1,0,n,0,v[p1].second,1);
            p1++;
        }
        while(p2<(int)v.size()&&v[p2].first<=b[i]){
            upd(1,0,n,v[p2].second+1,n,-1);
            p2++;
        }
        ans+=ts[1].val;
    }
    cnt=0;
    for(int i=0;i<(int)v.size();i++)c[cnt++]=v[i].first;
    sort(c,c+cnt);
    cnt=unique(c,c+cnt)-c;
    for(int i=0;i<n;i++){
        int id=lower_bound(c,c+cnt,a[i])-c;
        ans+=qry(id+2);
        upd(id+1);
    }
    print(ans);
    putchar('\n');
}
int main(){
    t=read();
    while(t--){
        solve();
    }
    return 0;
}
/*inline? ll or int? size? min max?*/
/*
1
3 3
3 2 1
1 2 3
*/

sol2

考虑分治 .

因为每个 $b$ 都能取到最优的最小位置 . 令 $mid=\lfloor \frac{l_2+r_2}{2}\rfloor$ . 考虑 $b_{mid}$ 在区间 $a$ 的 $[l_1,r_1]$ 中的位置,可以 $O(r_1-l_1)$ 地计算出 ,钦定为位置 $p$ . 那么,分治成 $solve(l_1,p,l_2,mid)$ 和 $solve(p,r_1,mid,r_2)$ . 递归 $\log n$ 层 . 时间复杂度 : $O(n\log n)

空间复杂度 : O(n)

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    int res=0;
    while(ch>='0'&&ch<='9'){
        res=(res<<3)+(res<<1)+ch-'0';
        ch=getchar();
    }
    return res;
}
inline void print(long long res){
    if(res==0){
        putchar('0');
        return;
    }
    int a[20],len=0;
    while(res>0){
        a[len++]=res%10;
        res/=10;
    }
    for(int i=len-1;i>=0;i--)
        putchar(a[i]+'0');
}
const int inf=1e9+10;
int t;
int n,m;
int a[1000010],b[1000010];
vector<int>v[1000010];
int cnt1[1000010],cnt2[1000010];
int bit[2000010],c[2000010],cnt=0;
void init(){
    for(int i=0;i<=n;i++)v[i].clear();
    for(int i=0;i<=cnt;i++)bit[i]=0;
    cnt=0;
}
void solve(int lb,int rb,int la,int ra){
    for(int i=la;i<=ra;i++)cnt1[i]=cnt2[i]=0;
    int mid=(lb+rb)>>1;
    for(int i=la;i<=ra-1;i++){
        if(a[i]>b[mid])cnt1[i+1]++;
        if(a[i]<b[mid])cnt2[i]++;
    }
    for(int i=ra-1;i>=la;i--)cnt2[i]+=cnt2[i+1];
    for(int i=la+1;i<=ra;i++)cnt1[i]+=cnt1[i-1];
    int mn=inf;
    for(int i=la;i<=ra;i++)mn=min(mn,cnt1[i]+cnt2[i]);
    for(int i=la;i<=ra;i++){
        if(cnt1[i]+cnt2[i]==mn){
            v[i].push_back(b[mid]);
            if(lb<=mid-1)solve(lb,mid-1,la,i);
            if(mid+1<=rb)solve(mid+1,rb,i,ra);
            return;
        }
    }
}
int upd(int i){
    while(i){
        bit[i]++;
        i-=i&-i;
    }
}
int qry(int i){
    int res=0;
    while(i<=cnt){
        res+=bit[i];
        i+=i&-i;
    }
    return res;
}
void solve(){
    init();
    n=read();m=read();
    for(int i=0;i<n;i++)a[i]=read();
    for(int i=0;i<m;i++)b[i]=read();
    sort(b,b+m);
    solve(0,m-1,0,n);
    for(int i=0;i<n;i++)c[cnt++]=a[i];
    for(int i=0;i<m;i++)c[cnt++]=b[i];
    sort(c,c+cnt);
    cnt=unique(c,c+cnt)-c;
    long long ans=0;
    for(int i=0;i<=n;i++){
        for(int j=0;j<(int)v[i].size();j++){
            int id=lower_bound(c,c+cnt,v[i][j])-c;
            ans+=qry(id+2);
        }
        for(int j=0;j<(int)v[i].size();j++){
            int id=lower_bound(c,c+cnt,v[i][j])-c;
            upd(id+1);
        }
        if(i<n){
            int id=lower_bound(c,c+cnt,a[i])-c;
            ans+=qry(id+2);
            upd(id+1);
        }
    }
    print(ans);
    putchar('\n');
}
int main(){
    t=read();
    while(t--){
        solve();
    }
    return 0;
}
/*inline? ll or int? size? min max?*/
/*
1
5 8
1 3 2 4 3
8 2 3 4 2 3 2 1
*/

写在最后,也可以去我的cnblog看一看哦 link