题解 P4165 【[SCOI2007]组队】

· · 题解

H 代表 minhV 代表 minv

考虑满足条件的队员应满足 3 个条件:

  1. A*h+B*v\le C+A*H+B*V
  2. h\ge H
  3. v\ge V

考虑如果满足 1 和 3,则能得出 4. h \le \frac{C}{A}+H

但是根据 1 和 4 或 3 和 4,却不能推出 3 或 1,所以这段好像没啥用?

s 表示 A*h+B*v

考虑外层枚举 H,内层枚举 V,将数组按照 s 排序,那么满足第一个条件的位置会随着内层的枚举递增。

当然只有第一个条件是不够的,考虑如何增加其他的条件,因为外层枚举是固定的,所以显然可以在统计的答案的时候遵守限制 2。考虑满足第三个限制。如果只是简单的满足的话(如下

for(int i = 1; i <= n; ++i){ // H
        int r = 0, cnt = 0;
        for(int j = 1; j <= n; ++j){ // V
            while(r < n && a3[r + 1].s <= A * a1[i].h + B * a2[j].v)
                ++r, cnt += (a3[r].h >= a1[i].h && a3[r].v >= a2[j].v);
            // del 
        }
    }

那么随着 j 的增加,我们的删除是 O(n)

考虑如何做到 O(1) 的删除,考虑刚才推出的性质 4,如果满足 1 和 4 的话,的确没法推出 v\ge V,但是如果性质 4(h \le \frac{C}{A}+H 存在的话,假设此时 v<V 那么,我们能得到 A*h+B*v\le A*H+B*V+C

这有什么用呢?考虑在统计的时候并不一定满足性质 1、2、3 的答案,而是只满足 1 和 2 还有 4 的答案,然后再将多余的减掉,我们发现减的过程是单调的(如下

for(int i = 1; i <= n; ++i){ // h
        int l1 = 0, l2 = 0, cnt = 0;
        int Min = a1[i].h, Max = a1[i].h + C / A;   
        for(int j = 1; j <= n; ++j){ // v
            while(l2 < n && a3[l2 + 1].s <= C + A * a1[i].h + B * a2[j].v)
                ++l2, cnt += (a3[l2].h >= Min && a3[l2].h <= Max);
            while(l1 < n && a2[l1 + 1].v < a2[j].v)
                ++l1, cnt -= (a2[l1].h >= Min && a2[l1].h <= Max);  
            ans = max(ans, cnt);
        }
    }

可以发现这样是对的。另外,再删除过程中是不会删除没加入的点的。

证明如下:

首先,它被删除了,说明他满足不等式 a2[l1].h >= Mina2[l1].h <= Max,即 A*h\le A*H+Ch\ge H。我们知道他没有被添加,说明 A*h+B*v>A*H+B*V+C,解得 v>V,但 a2[l1 + 1].v < a2[j].v

并且也不会漏删加入的点(这个非常显然

并且在固定 i 和 j 之后,应该要被删除的点一定会在这次被删除,

分两种情况讨论:

1 应该删除的点在之后才被删除

证明:首先,它在这次被加入了,说明 A*h+B*v\le C+A*H+B*V,并且 h \le \frac{C}{A}+H。它应该被删除,说明 v<V,那么肯定会在现在被删除

2 应该删除的点在之前被删除了

证明:它被删除了,说明 h \le \frac{C}{A}+H,且 v<V,这个 V 是当时枚举的,那么在当时就有 A*h+B*v\le C+A*H+B*V,所以在当时就会被加入

所以删除满足单调性,所以,应该就没了。贴一下 AC 代码。

如有不严谨的地方,请指出

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 5010
using namespace std;

int A, B, C, n, m;

struct node{
    int h, v, s;    
}a1[maxn], a2[maxn], a3[maxn];

inline bool cmp1(node x, node y){return x.h < y.h;}
inline bool cmp2(node x, node y){return x.v < y.v;}
inline bool cmp3(node x, node y){return x.s < y.s;}

int ans;
int main(){
    scanf("%d%d%d%d", &n, &A, &B, &C);
    for(int i = 1; i <= n; ++i){
        scanf("%d%d", &a1[i].h, &a1[i].v);
        a1[i].s = A * a1[i].h + B * a1[i].v;
        a3[i] = a2[i] = a1[i];
    }
    //sort(a1 + 1, a1 + n + 1, cmp1); 
    sort(a2 + 1, a2 + n + 1, cmp2);
    sort(a3 + 1, a3 + n + 1, cmp3);
    for(int i = 1; i <= n; ++i){ // h
        int l1 = 0, l2 = 0, cnt = 0;
        int Min = a1[i].h, Max = a1[i].h + C / A;   
        for(int j = 1; j <= n; ++j){ // v
            while(l2 < n && a3[l2 + 1].s <= C + A * a1[i].h + B * a2[j].v)
                ++l2, cnt += (a3[l2].h >= Min && a3[l2].h <= Max);
            while(l1 < n && a2[l1 + 1].v < a2[j].v)
                ++l1, cnt -= (a2[l1].h >= Min && a2[l1].h <= Max);  
            ans = max(ans, cnt);
        }
    }
    cout << ans;
    return 0; 
}