题解:P16461 [UOI 2026] Lazy Student

· · 题解

洛谷首 A。

从第 1 列开始考虑,我们每次要当前列的元素和尽可能小(假设当前考虑到第 i 列),那么每次操作就希望交换过来的值与原来的值相差越大越好。假设我们要把 a_{x,i} 换掉,那么肯定将选择 \min_{j=i}^m a_{x,j} 换过来。也就是说,每次操作我们需要找到最大的 a_{x,i}-\min_{j=i}^m a_{x,j}

那么如果同时有很多个最大的 a_{x,i}-\min_{j=i}^m a_{x,j} 相等呢?我们希望这次操作对后续其他列的操作影响尽可能小,容易发现也就是要被交换的列尽可能靠后,也就是令 p_x 为满足 a_{x,p_x}=\min_{j=i}^m a_{x,j} 的最大下标,我们需要在保证 a_{x,i}-\min_{j=i}^m a_{x,j} 最大的情况下,p_x 也尽可能大。

那么根据上述贪心模拟即可,用线段树维护后缀最小值,逐列来做会好写一点。

::::info[Code]

#include <bits/stdc++.h>
using namespace std;
int a[1005][1005], aa[1005];
long long s[1005];

struct tree {
    pair<int, int>t[4005];
    void bu(int d, int l, int r) {
        if (l == r) {
            t[d].first = aa[l], t[d].second = l;
            return;
        }
        int mid = (l + r) / 2;
        bu(d * 2, l, mid), bu(d * 2 + 1, mid + 1, r);
        t[d] = min(t[d * 2], t[d * 2 + 1]);
    }
    void add(int d, int l, int r, int k, int z) {
        if (l == r) {
            t[d].first = z;
            t[d].second = k;
            return;
        }
        int mid = (l + r) / 2;
        if (k <= mid)
            add(d * 2, l, mid, k, z);
        else
            add(d * 2 + 1, mid + 1, r, k, z);
        t[d] = min(t[d * 2], t[d * 2 + 1]);
    }
    pair<int, int>ask(int d, int l, int r, int ll, int rr) {
        if (ll <= l && r <= rr)
            return t[d];
        int mid = (l + r) / 2;
        pair<int, int>ans;
        ans.first = 2e9, ans.second = 0;
        if (ll <= mid)
            ans = min(ans, ask(d * 2, l, mid, ll, rr));
        if (rr > mid)
            ans = min(ans, ask(d * 2 + 1, mid + 1, r, ll, rr));
        return ans;
    }
} tr[1005];

struct node {
    int cha, p, h;
    friend bool operator<(node x, node y) {
        if (x.cha != y.cha)
            return x.cha > y.cha;
        return x.p > y.p;
    }
};
node b[1000005];

int main() {
    int n, m, k;
    cin >> n >> m >> k;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            cin >> a[i][j];
            s[j] += a[i][j];
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++)
            aa[j] = a[i][j];
        tr[i].bu(1, 1, m);
    }
    for (int i = 1; i <= m; i++) {
        if (k <= 0)
            break;
        int cnt = 0;
        for (int j = 1; j <= n; j++) {
            auto x = tr[j].ask(1, 1, m, i, m);
            if (a[j][i] - x.first > 0)
                b[++cnt] = (node) {
                a[j][i] - x.first, x.second, j
            };
        }
        sort(b + 1, b + cnt + 1);
        for (int j = 1; j <= cnt; j++) {
            auto x = b[j];
            if (!k)
                break;
            s[x.p] += x.cha, s[i] -= x.cha;
            swap(a[x.h][i], a[x.h][x.p]);
            tr[x.h].add(1, 1, m, i, a[x.h][i]);
            tr[x.h].add(1, 1, m, x.p, a[x.h][x.p]);
            k--;
        }
    }
    for (int i = 1; i <= m; i++)
        cout << s[i] << " ";
}

::::