题解: AT_abc391_f [ABC391F] K-th Largest Triplet

· · 题解

题解: AT_abc391_f [ABC391F] K-th Largest Triplet

水题一道,只是有点思维难度而已了。

1. 解题思路

ABC 分别按降序排序。同时,定义函数 f(i, j, k) = A_i \times B_j + B_j \times C_k + C_k \times A_i

我们需要尝试枚举前 K 个最大的值,其中 K 的范围较小,K \le 2 \times 10^5。重要的是:

这意味着如果我们按降序枚举,f(i, j, k) 总是先于 f(i+1, j, k)f(i, j+1, k)f(i, j, k+1)

因此,可以通过以下算法解决该问题:

  1. 准备一个二叉堆 Q,将 \{f(0,0,0), 0, 0, 0\} 插入二叉堆中(这里我们可以用 vector 容器从 0 号下标开始存储)。
  2. 重复以下步骤 K 次:

    • 取出 Q 中的最大元素 \{val, i, j, k\}
    • 移除 \{val, i, j, k\},并把 \{f(i+1, j, k), i+1, j, k\}\{f(i, j+1, k), i, j+1, k\}\{f(i, j, k+1), i, j, k+1\} 插入到二叉堆中(如果他们尚未在 Q 内)。
  3. 时间复杂度为 O(N \log N + K \log K)

2. 代码实现

就知道你们想要这个。

#include <bits/stdc++.h>
using namespace std;
#define int long long
struct Sum {
    int value;
    int i, j, k;
    Sum(int v, int x, int y, int z) : value(v), i(x), j(y), k(z) {}
    bool operator < (const Sum &other) const {
        return value < other.value;
    }
};
int find_kth(int N, int K, vector<int> &a, vector<int> &b, vector<int> &c) {
    sort(a.rbegin(), a.rend());
    sort(b.rbegin(), b.rend());
    sort(c.rbegin(), c.rend());
    priority_queue<Sum> pq;
    pq.push(Sum((int)a[0] * b[0] + (int)b[0] * c[0] + (int)c[0] * a[0], 0, 0, 0));
    set<tuple<int, int, int>> st;
    st.insert(make_tuple(0, 0, 0));
    int count = 0;
    while (!pq.empty()) {
        Sum cur = pq.top();
        pq.pop();
        count++;
        if (count == K) return cur.value; 
        if (cur.i + 1 < N && st.find(make_tuple(cur.i + 1, cur.j, cur.k)) == st.end()) {
            st.insert(make_tuple(cur.i + 1, cur.j, cur.k));
            pq.push(Sum((int)a[cur.i + 1] * b[cur.j] + (int)b[cur.j] * c[cur.k] + (int)c[cur.k] * a[cur.i + 1], cur.i + 1, cur.j, cur.k));
        }
        if (cur.j + 1 < N && st.find(make_tuple(cur.i, cur.j + 1, cur.k)) == st.end()) {
            st.insert(make_tuple(cur.i, cur.j + 1, cur.k));
            pq.push(Sum((int)a[cur.i] * b[cur.j + 1] + (int)b[cur.j + 1] * c[cur.k] + (int)c[cur.k] * a[cur.i], cur.i, cur.j + 1, cur.k));
        }
        if (cur.k + 1 < N && st.find(make_tuple(cur.i, cur.j, cur.k + 1)) == st.end()) {
            st.insert(make_tuple(cur.i, cur.j, cur.k + 1));
            pq.push(Sum((int)a[cur.i] * b[cur.j] + (int)b[cur.j] * c[cur.k + 1] + (int)c[cur.k + 1] * a[cur.i], cur.i, cur.j, cur.k + 1));
        }
    }
    return -1;
}
signed main() {
    int N, K; cin >> N >> K;
    vector<int> a(N), b(N), c(N);
    for (int i = 1; i <= N; i++) {
        int x; cin >> x;
        a[i - 1] = x;
    }
    for (int i = 1; i <= N; i++) {
        int x; cin >> x;
        b[i - 1] = x;
    }
    for (int i = 1; i <= N; i++) {
        int x; cin >> x;
        c[i - 1] = x;
    }
    cout << find_kth(N, K, a, b, c) << endl;
    return 0;
}

其中,我使用了 set 容器来判断是否已经在二叉堆 Q 中。

测评记录