P7828 [CCO2021] Swap Swap Sort 题解

· · 题解

先抽象一下题意,对于一种排列,要求的最小交换次数其实就是重赋值后的逆序对个数,比如初始排列就是求正常的逆序对数。

我们不可能每次修改排列都重新求一遍逆序对,那么可以尝试考虑一下单次修改后逆序对是如何变化的。

不难发现,由于修改排列是交换相邻的两个数,这意味着原序列中只有两种值的关系发生变化,逆序对个数只会受这两种值影响。

具体地说,设 (x,y) 表示序列中 x 值在前 y 值在后的数对的个数,交换这两个数在排列中的位置后,若在交换前 xy 靠前,逆序对 ans=ans-(x,y)+(y,x)

容易发现对于任意 x\ne y,均有 (x,y)+(y,x)=c_xc_y,其中 c_kk 在原序列中的出现次数。因为 (x,y)+(y,x) 是所有由 x,y 组成的数对的个数,其等于 c_xc_y

(y,x)=c_xc_y-(x,y) 代入原式,则有 ans=ans+c_xc_y-2(x,y),于是问题变成了求 (x,y)

尝试根号分治,设阈值为 S

对于每一次要交换的 x,y,若 c_x<Sc_y<S,发现可以预处理出每一个值在原序列中的出现位置,利用贡献的单调性使用双指针 O(S) 求出 (x,y)

c_x\ge S,显然所有大于等于 S 的数不会超过 \frac n S 种,考虑离线处理所有涉及这些数的询问。将所有 c_x\ge S 的询问的信息挂在 x 上。离线处理单个 x 时,可以 O(n) 求出所有 (x,y),y\ne x,然后遍历所有挂在 x 上的询问,可以直接对每个询问 O(1) 交付已经求得的 (x,y)

c_x< Sc_y\ge S,可以将 ans=ans+c_xc_y-2(x,y) 变换成 ans=ans-c_xc_y+2(y,x),可以将询问信息挂在 y 上,和上面的那种情况一起处理。

复杂度为 O(\frac n S\cdot n+S\cdot q),根据高中数学知识,可知 S 理论取 \frac n {\sqrt q}=100 最优。

这样一来我们便求出了每一个询问相比上一个询问答案的增值。

求出初始逆序对后累加增值即为最后答案。

#include<iostream>
#include<cmath>
#include<cstring>
#include<vector>
using namespace std;
//#define int long long
inline int read(){
    int i=getchar(),r=0;
    while(i<'0'||i>'9')i=getchar();
    while(i>='0'&&i<='9')r=(r<<1)+(r<<3)+(i^48),i=getchar();
    return r;
}
const int N=100100,S=100;
int n,a[N],b[N],q,k;
long long cnt[N];
vector<int>pos[N];
struct offline{
    long long id,y,sign;
};
vector<offline>t[N];
long long ans[N*10];
long long bta[N];
inline void Add(int i,int v){while(i<=k)bta[i]+=v,i+=i&-i;}
inline int Get(int i){long long r=0;while(i)r+=bta[i],i-=i&-i;return r;}
void init(){
    cin>>n>>k>>q;
    for(int i=1;i<=k;i++)b[i]=i;
    for(int i=1;i<=n;i++){
        a[i]=read();
        pos[a[i]].emplace_back(i);
        cnt[a[i]]++;
    }
    for(int i=1;i<=n;i++){
        Add(a[i],1);
        ans[0]+=i-Get(a[i]);
    }
}
inline long long get_f(int x,int y){//前x后y的对数
    long long res=0;
    int j=0;
    for(int i=0;i<cnt[x];i++){
        while(j<cnt[y]&&pos[y][j]<=pos[x][i])j++;
        res+=cnt[y]-j;
    }
    return res;
}
long long f[N];
signed main(){
//    freopen("read.in","r",stdin);
    init();
    for(int Q=1;Q<=q;Q++){
        int x=read();
        int y=x+1;
        swap(b[x],b[y]);
        if(cnt[b[x]]<S&&cnt[b[y]]<S)ans[Q]+=cnt[b[x]]*cnt[b[y]]-2*get_f(b[x],b[y]);
        else{
            if(cnt[b[x]]>=S)ans[Q]+=cnt[b[x]]*cnt[b[y]],t[b[x]].emplace_back((offline){Q,b[y],-2});
            else ans[Q]-=cnt[b[x]]*cnt[b[y]],t[b[y]].emplace_back((offline){Q,b[x],2});
        }
    }
    for(int i=1;i<=k;i++){
        if(t[i].empty())continue;
        memset(f,0,sizeof(f));
        long long c=0;
        for(int j=1;j<=n;j++){
            c+=(a[j]==i);
            f[a[j]]+=c;
        }
        int siz=(int)t[i].size();
        for(int j=0;j<siz;j++)ans[t[i][j].id]+=t[i][j].sign*f[t[i][j].y];
    }
    for(int i=1;i<=q;i++)ans[i]+=ans[i-1],printf("%lld\n",ans[i]);
    return 0;
}