题解:P10139 [USACO24JAN] Nap Sort G

· · 题解

提供一个瓶颈是基数排序的做法

首先对原数组排序,方便考虑。

考虑先枚举 Bessie 堆的大小 m

如果 a_n 是助手堆的,那么答案就是 a_n;否则答案是 \frac{m\cdot(m+1)}{2}。因此强制让 a_n 放入 Bessie 堆。

那么现在的问题就是 check m 是否存在包含 a_n 的合法方案。

在一个方案中,Bessie 堆的数把原序列划分成若干段,那么原序列中助手堆的数,会有对应的合法值域区间,每个助手堆中的数必须在相应区间中,方案才能合法。

因此我们可以 O(n) 从后往前分段。假设枚举到 a_i,如果 a_i 是在当前合法值域区间里,就把 a_i 分到这一段,否则 a_i 就设为这一段的左端点(即把 a_i 放入 Bessie 堆中),然后考虑 a_{i-1}。放完最后一个 Bessie 数后,再 check 序列前面剩下的一些数是否在合法区间中,即可。这样做是 O(n^2) 的。

这样做其实非常麻烦,因为我们只关心是否存在方案,而不关心具体方案。

考虑不合法的情况长什么样。已经选了 a_n,考虑 a_{n-1},如果 a_{n-1}<\sum\limits_{j=1}^{m}j,这就意味着 a_1,a_2,\dots,a_{n-1} 是一定存在方案的,不用再往前分段考虑了。

否则,我们就必须选 a_{n-1},然后继续依次往前考虑 a_{n-x},直到 a_{n-x}<\sum\limits_{j=x}^{m}j

也就是说,m 合法,当且仅当存在 i\in[n-m,n-1],使得 a_{i}<\sum\limits_{j=n-i}^{m}j。于是就有了更简便的 O(n^2) 写法。

#include <bits/stdc++.h>
using namespace std;
using LL = long long;

const int N = 2e5 + 5;

int n;
LL a[N], b[N];

LL sum(int l, int r) {
    return b[r] - b[l - 1];
}

void solve() {
    read(n);
    for (int i = 1; i <= n; i++) {
        read(a[i]);
        b[i] = b[i - 1] + i;
    }
    sort();
    LL ans = a[n];
    for (int i = 1; i <= n; i++) {
        for (int j = n; j >= n - i + 1; j--) {
            if (a[j - 1] < sum(n - j + 1, i)) {
                ans = min(ans, sum(1, i));
            }
        }
    }
    printf("%lld\n", ans);
}

继续考虑优化,发现 i,j 可以交换枚举顺序,然后二分,即可做到 O(n\log n)

#include <bits/stdc++.h>
using namespace std;
using LL = long long;

const int N = 2e5 + 5;

int n;
LL a[N], b[N];

LL sum(int l, int r) {
    return b[r] - b[l - 1];
}

void solve() {
    read(n);
    for (int i = 1; i <= n; i++) {
        read(a[i]);
        b[i] = b[i - 1] + i;
    }
    sort();
    LL ans = a[n];
    for (int i = 1; i <= n; i++) {
        int len = n - i + 1, l = len, r = n, mid;
        while (l < r) {
            mid = l + r >> 1;
            if (a[i - 1] < sum(len, mid)) r = mid;
            else l = mid + 1;
        }
        if (a[i - 1] < sum(len, r)) {
            ans = min(ans, sum(1, r));
        }
    }
    printf("%lld\n", ans);
}

更快的做法就是,用初中二次函数相关知识,推一下式子,做到 O(1) 求,那么复杂度瓶颈就是 sort 了,求答案是 O(n) 的。然后我选择 256 进制的基数排序,即可做到非常接近线性。

完整代码:

提交记录

#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using DB = double;

namespace FIO {
    char buf[1 << 20], *_now = buf, *_end = buf;
    inline char _getchar() {
        return _now == _end && (_end = (_now = buf) + fread(buf, 1, 1 << 20, stdin), _now == _end) ? EOF : *_now++;
    }

    template <typename T>
    void read(T &x) {
        x = 0;
        char c = getchar();
        for (; c < '0' || c > '9'; c = getchar());
        for (; c >= '0' && c <= '9'; x = (x << 3) + (x << 1) + (c ^ 48), c = getchar());
    }
} using FIO::read;

const int N = 2e5 + 5;

int n;
LL a[N], b[N];

LL sum(int l, int r) {
    return b[r] - b[l - 1];
}

// 基数排序
int cnt[260];
LL tmp[N];
void sort() {
    for (int i = 0; i <= 40; i += 8) {
        memset(cnt, 0, sizeof(cnt));
        for (int j = 1; j <= n; j++) cnt[(a[j] >> i) & 255]++;
        for (int j = 1; j < 256; j++) cnt[j] += cnt[j - 1];
        for (int j = n; j >= 1; j--) tmp[cnt[(a[j] >> i) & 255]--] = a[j];
        for (int j = 1; j <= n; j++) a[j] = tmp[j];
    }
}

void solve() {
    read(n);
    for (int i = 1; i <= n; i++) {
        read(a[i]);
        b[i] = b[i - 1] + i;
    }
    sort();
    LL ans = a[n];
    for (int i = 1; i <= n; i++) {
        int len = n - i + 1;
        LL x = a[i - 1] + a[i - 1] + (LL)len * len - len;
        LL dlt = x + x + x + x + 1;
        if (dlt < 0) {
            ans = min(ans, sum(1, len));
            continue;
        }
        LL k = floor((-1 + sqrt(dlt)) / 2 + 1);
        if (k <= n) {
            ans = min(ans, sum(1, max((int)k, len)));
        }
    }
    printf("%lld\n", ans);
}

int main() {
    int T;
    read(T);
    while (T--) solve();
    return 0;
}