题解:AT_arc101_b [ABC107D] Median of Medians

· · 题解

我敢打赌,这肯定是洛谷这道题里最详细的题解了。

思路:

写这道题的时候可以先去 P3031 看一下,这篇题解就以此为突破口。

对于 P3031 这道题,我们可以先打暴力。

P3031 の 40 分解法:

我们学过,如果想要找子串,我们可以枚举左端点和右端点,如下:

    for(int l=1;l<=n;l++){
        for(int r=l;r<=n;r++){
            /*这里写中位数判断*/
        }
    }

于是我们这个暴力只需要注意判断子串中位数是否大于等于 x

我们将中位数判断部分记为函数 check(l,r)。那么我们改怎么写 check(l,r) 呢?

很明显,作为暴力,我们可以再用一个数组来存储 a_l\sim a_r,记为 b,长度为 m。我们需要判断的便是 b_{\lfloor \frac{m}{2}\rfloor+1}\ge x。中位数的定义如下:

b 按升序排序得到数列 b'。此时,b' 的第 M\ /\ 2\ +\ 1 个元素的值即为 b 的中位数。这里,/ 表示向下取整的除法。

定义来源于题目。

判断完 b_{\lfloor \frac{m}{2}\rfloor+1}\ge x 后,如果满足就加 1,对于每个子串都来一遍,输出即可。

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,x;
int a[1000005];
int b[1000005];
int check(int l,int r){
    int m=0;
    for(int i=l;i<=r;i++){
        b[++m]=a[i];
    }
    sort(b+1,b+m+1);
    return (b[m/2+1]>=x?1:0);
}
signed main(){
    cin.tie(0)->sync_with_stdio(0);
    cout.tie(0)->sync_with_stdio(0);
    cin>>n>>x;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    int ans=0;
    for(int l=1;l<=n;l++){
        for(int r=l;r<=n;r++){
            ans+=check(l,r);
        }
    }
    cout<<ans;
    return 0;
}

P3031 の 80 分解法:

通过暴力得到 40 分后,我们便要向 80 分进发。

我们可以发现,对于每一个字串,如果希望 b_{\lfloor \frac{m}{2}\rfloor+1}\ge x,那么就肯定满足这个性质:

这个性质是怎么来的?我们可以想,排完序后,如果 b_{\lfloor \frac{m}{2}\rfloor+1}\ge x 满足,那么大于等于 x 的数字的数量一定大于等于 m-(\lfloor \frac{m}{2}\rfloor+1)+1,即 m-\lfloor \frac{m}{2} \rfloor。而这个数,一定满足 m-\lfloor\frac{m}{2}\rfloor \ge \lfloor\frac{m}{2}\rfloor

这里也就间接的证明了 x-y\ge0。因为由上定义可以得到 x=m-\lfloor \frac{m}{2}\rfloor,y=\lfloor \frac{m}{2}\rfloor。这里的 \lfloor \frac{m}{2} \rfloor 在代码中就是 m/2,编译器会自动向下取整。

最后便是如何计算 x-y。这个问题很简单,我们记大于等于 x 的数字为 1,小于 x 的数字为 -1。进行前缀和后,记前缀和数组为 sum,对于区间 [l,r],答案便为 sum_r-sum_{l-1}

#include<bits/stdc++.h>
#define int long long
using namespace std;
int a[1000005],sum[1000005];
int n,m,k,x;
signed main(){
    cin>>n>>x;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        if(a[i]>=x){
            sum[i]=1;
        }
        else{
            sum[i]=-1;
        }
        sum[i]+=sum[i-1];
    }
    int ans=0;
    for(int l=1;l<=n;l++){
        for(int r=l;r<=n;r++){
            if(sum[r]-sum[l-1]>=0){
                ans++;
            }
        } 
    }
    cout<<ans;
    return 0;
}

P3031 の 100 分解法:

拿下 80 分后,100 分便迎刃而解。

我们观察一下 80 分代码的 sum_r-sum_{l-1}\ge0,调整一下得:

sum_{l-1}\le sum_r

我们将其与 l-1< r 联立得:

\begin{cases} l-1< r\\ sum_{l-1}\le sum_r \end{cases}

似乎在哪里见过这个式子,我们将 l-1 替换为 ir 替换为 j,便可得到:

\begin{cases} i< j\\ sum_i\le sum_j \end{cases}

这下完全想起来了,这不就是 P1908 逆序对的颠倒版吗!

我们知道逆序对是可以用树状数组来记录之前比 a_i 大得数字,当然也可以记录之后比 a_i 小的数字。而这里需要记录的是之前小于等于 a_i 的数字。为什么?原因就是上面的式子,我不信有理解不了的。

接下来就是处理正负数的问题。很明显,数字区间是 -100000\sim100000,而树状数组处理不了小于等于 0 的数字,所以这里统一加上 100001,这样子数字区间 1\sim200001 了。

最后还要注意存入时,第一步时 sum_0,这一个不能漏了,记得还要开 long long

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int b=100001;
int a[1000005],sum[1000005];
int tr[1000005];
int n,m,k,x;
int ans[1000005];
int lowbit(int x){
    return x&-x;
}
void modify(int x,int v){
    while(x<=210001){
        tr[x]+=v;
        x+=lowbit(x);
    }
}
int query(int x){
    int ans=0;
    while(x){
        ans+=tr[x];
        x-=lowbit(x);
    }
    return ans;
}
signed main(){
    cin>>n>>x;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        if(a[i]>=x){
            sum[i]=1;
        }
        else{
            sum[i]=-1;
        }
    }
    modify(0+b,1);
    int ans=0;
    for(int i=1;i<=n;i++){
        sum[i]+=sum[i-1];
        ans+=query(sum[i]+b);
        modify(sum[i]+b,1);
    }
    cout<<ans;
    return 0;
}

AT_arc101_b の 100 分解法:

讲了这么久的 P3031,是时候该讲 AT_arc101_b 了。

我们可以发现,其实我们可以对这道题经行二分答案。为什么?我们的目标是找到一个 x,使其执行 P3031 的 100 分代码后可以满足 ans\ge\lfloor \frac{n}{2} \rfloor+1,于是我们可以找到单调性:

于是我们可以以 1\sim n 为区间,每次判断的是 a_{mid},如果 a_{mid} 符合,那么 l=mid,反之 r=mid-1

最重要的还属 check 函数。check 函数我们的代码可以沿用 P3031 的 100 分代码。我们唯独要修改的便是最后判断是否满足 ans\ge\lfloor \frac{n}{2} \rfloor+1 即可,感觉没有什么要讲的了。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int b=100001;
int a[1000005],c[1000005],sum[1000005];
int tr[1000005];
int n,m,k;
int ans[1000005];
int lowbit(int x){
    return x&-x;
}
void modify(int x,int v){
    while(x<=210001){
        tr[x]+=v;
        x+=lowbit(x);
    }
}
int query(int x){
    int ans=0;
    while(x){
        ans+=tr[x];
        x-=lowbit(x);
    }
    return ans;
}
bool check(int x){
    memset(tr,0,sizeof tr);
    memset(sum,0,sizeof sum);
    for(int i=1;i<=n;i++){
        if(a[i]>=x){
            sum[i]=1;
        }
        else{
            sum[i]=-1;
        }
    }
    modify(0+b,1);
    int ans=0;
    for(int i=1;i<=n;i++){
        sum[i]+=sum[i-1];
        ans+=query(sum[i]+b);
        modify(sum[i]+b,1);
    }
    if(ans>=(n*(n+1)/2+1)/2) return 1;
    return 0;
}
signed main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        c[i]=a[i];
    }
    sort(c+1,c+n+1);
    int l=1,r=n;
    while(l<r){
        int mid=(l+r+1)/2;
        if(check(c[mid])) l=mid;
        else r=mid-1;
    }
    cout<<c[l];
    return 0;
}

后话:

本篇题解编辑历时 1 小时,给个赞再走呗。