P14361 题解

· · 题解

如果不考虑 \frac{n}{2} 的限制我们直接取 \max 就做完了。设这时求得的答案为 ans

如果这时满足限制直接输出 ans 就行了。

发现最多只会有一个部门不满足这个限制。

设这个不满足限制的部门是部门 1

如果不是的话那我们就直接 swap 一下就行。

我们现在要做的就是最小化这个限制带来的影响。

dif_i = \min(a_{i, 1} - a_{i, 2}, a_{i, 1} - a_{i, 3})。表示这个人原本选了 1,现在选 23 对答案的最小影响。

我们按 dif 数组把选了部门 1 的人从小到大排序,将 ans 减去最小的前 c_1-\frac{n}{2} 个即是最终答案(c_1 为选了 1 的人数)。

因为这样能保证影响最小,而前面求得的 ans 又是最大的,所以答案一定最优。

代码很好写,场上 11min 写完。当时思路不太清晰写的有点史……

#include <bits/stdc++.h>
#define r(x) for (int i = 1; i <= x; i++)
#define rep(i, a, b) for (int i = a; i <= b; i++)
using namespace std;

const int N = 2e5 + 50;
int t, n, a[N][4], s1, s2, s3;
int c1, c2, c3, dif[N], flag[N]; 

signed main()
{   
    cin.tie(0)->sync_with_stdio(0);
    cin >> t;
    while (t--)
    {
        cin >> n; int p = n / 2;
        s1 = s2 = s3 = c1 = c2 = c3 = 0; 
        r(n)
        {
            cin >> a[i][1] >> a[i][2] >> a[i][3];
            if (a[i][1] >= a[i][2] && a[i][1] >= a[i][3])
                s1 += a[i][1], c1++, flag[i] = 1;
            else if (a[i][2] >= a[i][1] && a[i][2] >= a[i][3])
                s2 += a[i][2], c2++, flag[i] = 2;
            else if (a[i][3] >= a[i][1] && a[i][3] >= a[i][2])
                s3 += a[i][3], c3++, flag[i] = 3;               
        }
        int ans = s1 + s2 + s3, cnt = 0;
        if (c1 <= p && c2 <= p && c3 <= p)
        {
            cout << ans << "\n";
            continue;
        }
        if (c2 > p)
        {
            r(n)
            {
                swap(a[i][1], a[i][2]);
                if (flag[i] == 2) flag[i] = 1;
                else if (flag[i] == 1) flag[i] = 2;
            }
            swap(s1, s2), swap(c1, c2);
        }
        if (c3 > p)
        {
            r(n) 
            {
                swap(a[i][1], a[i][3]);
                if (flag[i] == 3) flag[i] = 1;
                else if (flag[i] == 1) flag[i] = 3;
            }
            swap(s1, s3), swap(c1, c3);
        }
        r(n) if (flag[i] == 1)
            dif[++cnt] = min(a[i][1] - a[i][2], a[i][1] - a[i][3]);
        sort(dif + 1, dif + cnt + 1);
        r(c1 - p) ans -= dif[i];
        cout << ans << "\n";
    }
    return 0;
}