题解:P11048 [蓝桥杯 2024 省 Java B] 拼十字

· · 题解

思路

对于第 i 个矩形,问存在多少个矩形 j 满足:

c_j \neq c_i, \\ l_j > l_i, \\ w_j < w_i

我们可以使用 CDQ 分治的思想求一个二维偏序,首先按 l 为第一关键字排序,再按 w 做第二关键字排序。但是注意到 l_i = l_j 时会出现计算不正确的情况,并且去重不太好处理,因此考虑容斥。

首先对于每个 i 计算出满足:

c_j \neq c_i, \\ l_j \geq l_i, \\ w_j < w_i

的数量,再减去满足:

c_j \neq c_i, \\ l_j = l_i, \\ w_j < w_i

的数量即可。

代码

#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, MOD = 1000000007;
int trval[3][N] = { 0 };
inline int lowbit(int i) { return i & (-i); }
inline void update(int tr[], int i, int val) {
    while (i < N) tr[i] += val, i += lowbit(i);
}
inline int query(int tr[], int i) {
    int ret = 0;
    while (i) ret = (ret + tr[i]) % MOD, i -= lowbit(i);
    return ret;
}
struct Ele {
    int l, w, c;
    bool operator< (const Ele& other) const {
        if (l != other.l) return l > other.l;
        return w < other.w;
    }
} ele[N], tmp[N];
int ans = 0;
void merge_sort(int l, int r) {
    if (l >= r) return;
    int mid = l + r >> 1;
    merge_sort(l, mid), merge_sort(mid + 1, r);
    int i = l, j = mid + 1, k = 0;
    while (i <= mid && j <= r)
        if (ele[i].w < ele[j].w) 
            update(trval[ele[i].c], ele[i].w, 1), tmp[k++] = ele[i++];
        else {
            for (int c = 0; c < 3; c++)
                if (c != ele[j].c)
                    ans = (ans + query(trval[c], ele[j].w - 1)) % MOD;
            tmp[k++] = ele[j++];
        }
    while(i <= mid)
        update(trval[ele[i].c], ele[i].w, 1), tmp[k++] = ele[i++];
    while (j <= r) {
        for (int c = 0; c < 3; c++)
            if (c != ele[j].c)
                ans = (ans + query(trval[c], ele[j].w - 1)) % MOD;
        tmp[k++] = ele[j++];
    }
    for (i = l; i <= mid; i++) update(trval[ele[i].c], ele[i].w, -1);
    for (i = l, j = 0; j < k; i++, j++) ele[i] = tmp[j];
}
void sub(int n) {
    int las = -1;
    for (int i = 1; i <= n; i++) {
        if (ele[i].l != las) {
            int cnt = 0;
            for (int j = i; j <= n; j++)
                if (ele[j].l == ele[i].l) {
                    update(trval[ele[j].c], ele[j].w, 1);
                    cnt++;
                }
                else break;
            for (int j = i, k = 0; k < cnt; j++, k++)
                for (int c = 0; c < 3; c++)
                    if (c != ele[j].c)
                        ans = (ans - query(trval[c], ele[j].w - 1) + MOD) % MOD;
            for (int j = i, k = 0; k < cnt; j++, k++)
                update(trval[ele[j].c], ele[j].w, -1);
        }
        las = ele[i].l;
    }
}
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) 
        scanf("%d%d%d", &ele[i].l, &ele[i].w, &ele[i].c);
    sort(ele + 1, ele + 1 + n);
    merge_sort(1, n);
    sort(ele + 1, ele + 1 + n);
    sub(n);
    printf("%d", ans);
    return 0;
}