P10235 [yLCPC2024] C. 舞萌基本练习 题解

· · 题解

题目传送门

大致题意:

多组测试数据,每组数据给定一个长度为 n 的序列和一个参数 k,要求将此区间划分成不超过 k 段,使这些区间中的逆序对数量的最大值最小。

思路:

对于求“最大值最小”这类问题,很容易想到二分。显然,本题的答案是具有单调性的,即当划分的段数减少时,区间中逆序对数量的最大值只会增大不会减小

所以我们考虑将二分答案转化为二分判定,具体为:我们二分一个 limit 值,表示当前的解,然后将当前这个 limit 代入计算。

若当前解合法,则说明区间 [limit,r] 的解肯定都是合法的,因为此时 limit 也可能作为最后答案,所以令 r = mid,向左扩展答案即可。

若当前解不合法,则说明区间 [l,limit] 的解肯定都不合法,所以此时令 l = mid + 1 向右扩展答案即可。

二分的问题解决了,那么怎么判断这个解是否合法呢?

一个很简单的想法:扫描整个序列,一开始整个序列只有一段,将序列中的数一个一个地加入这个段中,当此段逆序对数量大于 limit 时,就重新开辟新的一段并使段数 cnt 增加 1。将整个序列划分完后,若 cnt \le k,则此解合法,否则不合法。

对于求逆序对数量,用树状数组可以很好解决。

建立一个权值树状数组,每次在某段末尾加入一个数时,只需计算该段中大于它的数的个数,这就是新增的逆序对数。

然而这里要注意两个点:

  1. 在重新开辟一段时,之前那段的数要全部从树状数组中抹去,这样才能让它正确地求出后面段的逆序对数。
由于长度为 $n$ 的序列最大逆序对数量为 $\frac{n(n-1)}{2}$,所以我的二分上界取了 $10^{10}$。 在每次二分中,序列中的所有数只会进入一次、出一次树状数组,所以 $\operatorname{check()}$ 的时间复杂度为 $O(n\log n)$。 整个程序时间复杂度 $O(n\log n\log 10^{10})$,空间复杂度 $O(n)$。 $\texttt{Code}:
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdio>

#define lowbit(x) x & -x
using namespace std;

const int N = 100010;

int T, n, k;
int a[N], c[N];
int nums[N];
int tt;
int fnd[N];

int find(int x) {
    return lower_bound(nums + 1, nums + tt + 1, x) - nums;
}

int ask(int x) {
    int res = 0;
    for(; x; x -= lowbit(x)) res += c[x];
    return res;
}

void add(int x, int y) {
    for(; x <= n; x += lowbit(x)) c[x] += y;
}

bool check(long long limit) {
    int cnt = 1; //段数 
    long long f = 0; //目前处理的段的逆序对数 
    int L = 1; //目前处理的段的左端点 
    for(int i = 1; i <= n; i++) {
        int tmp = ask(tt) - ask(fnd[i]); //计算新增的逆序对数 
        if(f + tmp > limit) {
            cnt++; //段数 + 1
            f = 0; //重置逆序对数
            for(int j = L; j <= i - 1; j++) 
                add(fnd[j], -1); //清除上一区间的贡献 
            L = i; //更新左端点 
        }
        else f += tmp;
        add(fnd[i], 1); //加入树状数组
    }
    for(int i = L; i <= n; i++) add(fnd[i], -1); //不要忘了最后一段也要抹去
    return cnt > k;
}

int main() {
    scanf("%d", &T);
    while(T--) {
        scanf("%d%d", &n, &k);
        for(int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            nums[++tt] = a[i];
        }
        sort(nums + 1, nums + tt + 1);
        tt = unique(nums + 1, nums + tt + 1) - nums - 1;
        for(int i = 1; i <= n; i++) fnd[i] = find(a[i]);
        long long l = 0, r = 1e10;
        while(l < r) {
            long long mid = l + r >> 1;
            if(check(mid)) l = mid + 1;
            else r = mid;
        }
        printf("%lld\n", l);
        for(int i = 1; i <= tt; i++) nums[i] = 0;
        tt = 0;
    }
    return 0;
}