P8187 [USACO22FEB] Robot Instructions S题解

tzyt

2022-03-14 05:22:13

Solution

update: 2022/4/17 更正了博客地址 目录: 1. 暴力 2. 折半搜索 + map 或是哈希表 3. 折半搜索 + 双指针 + ~~和 dfs 不同~~奇怪的状态枚举方式 1. 第一种双指针 2. 第二种双指针 4. 完整代码 [题目链接](https://www.luogu.com.cn/problem/P8187) [博客](https://ttzytt.com/2022/03/P8187/)中观看体验更佳 # 1. 题意 给你 $n$ 个二维的向量,对于任意一个 $k (1 \le k \le n)$,求出有多少选取的方案能满足在这 $n$ 个向量中选 $k$ 个,并且他们的和为 $(x_g, y_g)$。 # 2. 分析 ## 2.1 暴力算法 看到这个题我们可以比较快的想到拿部分分的做法,就是暴力枚举所有的选取方案,然后看他们加起来是否等于目标向量,再把符合要求的方案累加到答案中。但是我们发现 $n = 40$,并且这个算法的复杂度是 $2^n$ 的,所以一定会超时。 ## 2.2 折半搜索 + map 或者是哈希表 ### 2.2.1 简要思路 这道题中的折半搜索指的是把 $n$ 种向量分成两部分,对这两个部分分别用暴力的方法求出所有可能的选取方案,再把这些选取的方案,以及这个方案的结果(他们的和)按照某种方案储存下来,最后匹配这两部分的方案,把符合题目要求的(和等于 $(x_g, y_g$)的)累加进答案里。 ### 2.2.2 细节 要达到前面提到的效果,我们可以使用 STL 的 map 或者是 unordered_map (手写哈希表也可以,但是可能要花更多时间来实现)来储存每个选取的方案。这里推荐使用 unordered_map,因为 map 的时间复杂度是 $O(\log n)$ 的,在这道题中会被卡,而 unodered_map 和手写哈希表的理想时间复杂度是 $O(1)$(当然使用 unodered_map 的话也要确保哈希函数写得好才不会被卡,~~比如我现在写的就过不了~~)。 对于暴力枚举状态的部分,一般的方法是 dfs ,也比较好写,这篇题解的双指针部分讲了一个比较奇怪的方法,想看的可以跳到下面。 注:下文的 map 指 unodered_map 或是 map 或是 手写的哈希表 我们首先创建两个 map ,`fir` 和 `sec`,分别储存前半部分向量的选取方案和后半部分的选取方案。对于这两个 map 的键值,我们可以设成包含 $(\text{sum}_x, \text{sum}_y, k)$ 三个整数的结构体。其中 $(\text{sum}_x, \text{sum}_y)$ 表示当前这个方案下,选取的所有向量的和。$k$ 表示当前的这个方案一共选取了多少个向量。因为可能有很多方案的 $x,y,k$ 都完全一样,所以对于 map 的值,我们设置成 $(\text{sum}_x, \text{sum}_y, k)$ 这三个值都相同的方案数。 对于答案的储存,我们开一个 `ans[n]` 的数组,`ans[i]` 表示 $k = i$ 时的选取方案数。 找到两部分的方案后我们要把符合要求的方案组合累加到答案中。具体来说,每个在 map 中储存的方案都包含一个这个方案的向量和 $(\text{sum}_x, \text{sum}_y)$ 。 我们设一个从前半部分向量得到的方案的向量和为 $\vec{v_1}$ , 一个后半部分方案的向量和为 $\vec{v_2}$ ,那么如果 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ ,我们就把这个方案记录进答案。 因为我们使用了 map ,所以并不需要真的用双层循环把每个方案都遍历一遍。我们知道对于一个可能的匹配方案 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ 。那么我们可以使用一个双层循环,一层枚举 fir 这个 map,一维枚举当前处理的 $k$ 。然后在循环里,我们可以写: `ans[当前k] += it_fir.值 * sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}]` 其中 `it_fir` 指的是 fir 的迭代器。而 `sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}]` 中的方案和 `it_fir` 遍历到的方案符合 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ ,并且这两种方案选取的向量数的和也等于当前 k 。因为这两个 map 的值是所有符合这些条件的方案的数量,所以我们把这两个值相乘以求出所有符合要求的匹配数量。 ## 2.3 折半搜索 + 双指针 双指针的理论复杂度似乎是和哈希表一样的,但是如果哈希函数有问题的话,哈希表的速度就会慢很多。而双指针就没有这个问题。 ### 2.3.1 双指针a 我们可以首先开两个 vector ,fir 和 sec ,其数据类型和之前 map 的一样,也是 $(\text{sum}_x, \text{sum}_y, k)$ 。fir 储存由前半部分的向量得来的方案,sec储存后半部分的。 枚举完所有方案后我们对这两个 vector 进行排序,排序规则如下: ```cpp if(a.x != b.x) return x < b.x; if(a.y != b.y) return y < b.y; return a.k < b.k; ``` 然后我们创建两个指针,$p_1$ 的初值设成 1, $p_2$ 的初值设成 `sec.size() - 1` ,代表 `sec` 的最后一个元素。此时 $p_1$ 指向的是 `fir` 中最小的元素,而 $p_2$ 指向 `sec` 中最大的元素。我们可以想一下,如果想要让当前的这两个指针所指的 $\vec{v_1}, \vec{v_2}$ 的和等于 $(x_g, y_g)$ ,需要怎么做。如果我们想要增加 $\vec{v_1} + \vec{v_2}$ 的值,只能使 $p_1$ 增加,因为 $p_2$ 已经指向这个数组里最大的元素了。反过来也是一样的,如果我们想要减少 $\vec{v_1} + \vec{v_2}$ 的值,也只能使 $p_2$ 减少。写成程序就是下面这样: ```cpp int p1 = 0, p2 = sec_half.size() - 1; while(p1 < fir_half.size() && p2 >= 0){ Instruct &f = fir_half[p1], &s = sec_half[p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ //如果两个向量相加小于目标值,我们只能加 p1 的值, //因为 p2 指向的元素最开始就是最大的。 p1++; } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ //如果两个向量相加大于目标值,我们只能减 p2 的值, //因为 p1 指向的元素最开始就是最小的。 p2--; } //下面后半段代码插入的位置 } ``` 注:`Instruct` 是包含 `{x, y, k}` 三种整数的结构体。 通过这样的方法,我们最终就一定能找到 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ 的情况。不过呢,这两个数组中都可能有连续的一段是完全一样的值,也就是有多个 $p_1$ , $p_2$ 满足 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ 。因此我们需要找出符合条件的这个连续段具体是什么。通过上面的代码,我们已经知道了,满足条件的最小 $p_1$ 和最大的 $p_2$,因为我们希望找出连续段的具体范围,所以还需要找出最大的 $p_1$ 和最小的 $p_1$ 。那么如何找呢?很简单,因为连续的这一段值一定都完全相等,所以我们只需要判断当前元素是否和最开始的元素相等就可以了。 当然,因为我们还需要把符合的匹配统计进答案,而答案是按照 $k$ 来输出的。所以我们可以开两个数组 `fir_same_k` 和 `sec_same_k` 。`fir_same_k[i]` 就表示,对于第一个数组,在符合条件的这一长段中, $k = i$ 的有多少。而 `sec_same_k` 是对于第二个数组的。 然后我们就可以得到下面的代码了: 注意这段代码是插入前面那段代码的 `else if` 后面的 ```cpp else{ int p1t, p2t; memset(fir_same_k, 0, sizeof(fir_same_k)); memset(sec_same_k, 0, sizeof(sec_same_k)); //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组 for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){ //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1 fir_same_k[fir_half[p1t].k]++; } for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){ //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2 sec_same_k[sec_half[p2t].k]++; } //统计答案,对于前半段和后半段都枚举可能的 for(int i = 0; i <= 20; i++){ for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的 ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j]; //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同 } } p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环 } ``` 这个方法还是跑的相对比较快的,可以看下[提交记录](https://www.luogu.com.cn/record/71008837) ### 2.3.2 双指针 b 我们发现双指针 a 的方法会需要在统计答案时开 `fir_same_k` 和 `sec_same_k` 这两个数组来统计 $k$ 相同的情况。我们其实可以改进一下这个方法,直接在枚举状态的时候把 $k$ 相同的方案放到一起。 具体来说,我们把前面的 `fir` 和 `sec` 这两个 vector 改成 `vector<Instruct> fir[20], sec[20]` 。`fir[i]` 就储存前半部分 $k = i$ 时的所有方案,`sec[i]` 是后半部分的。既然把储存方案的方法改了,后面的双指针部分自然也要改。 这一次我们需要用一个双重循环来分别枚举不同的 `fir[i]` 和 `sec[j]`。在循环内部我们再做和双指针 a 相似的事情。也就是说,我们已经知道了当前前半部分的方案和后半部分的方案的 $k$,现在只需要通过双指针找出能满足 $\vec{v_1} + \vec{v_2} = (x_g, y_g)$ 的 $p_1$ 和 $p_2$ 范围。 ```cpp for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){ //枚举当前前半部分的 k for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k int p1 = 0, p2 = sec_half[sec_k].size() - 1; while(p1 < fir_half[fir_k].size() && p2 >= 0){ Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ p1++; //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1 } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ p2--; //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2 } else{ int p1t = p1, p2t = p2; while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){ p1t++; } while(p2t >= 0 && sec_half[sec_k][p2t] == s){ p2t--; } ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); //把 p1 范围长度乘上 p2 范围的长度 //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while // 在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果 // 它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。 p1 = p1t, p2 = p2t; } } } } ``` 那么双指针b的写法比a有什么好处呢?答案就是节省空间。如果我们使用的是双指针a,那储存方案的结构体必须包含 $x, y, k$ 三种整数。注意其中的 $k$ 最大只有 $20$,而我们却必须开一个 `int` 或是 `short` 来储存这个值,考虑到 $20$ 这个值非常小,不管哪种数据类型都会浪费大量的空间。而采用双指针b后,我们的结构体中只会包含 $x, y$ 两种整数, $k$ 这个值储存在数组的下标中,只要你开的数组大小为 $k$ 的最大值,就不会有任何的浪费。 具体的对比可以参考这个[提交记录](https://www.luogu.com.cn/record/71362478),可以发现相比双指针 a 的做法,双指针 b 的内存占用大约少了 17MB 左右 当然,代价也是有的,双指针 b 会稍微慢一些。我估计这主要出在双指针的环节。排序的部分甚至还会快一点。当然,不管哪种双指针,他们的理论复杂度都是一样的,因为每一种选取方案最多会被遍历到一次。 ## 2.4 奇怪的状态枚举方法 这道题中常见的状态枚举方法就是 dfs。这里提供一种比较奇怪的枚举方法。在一个选取方案中,对于每个向量,都有两种状态,选或者是不选。因为只有选或不选两种状态,我们可以想到通过二进制数字表示这个状态。数字的第 $i$ 位表示向量 $i$ 是选还是不选,例如:$(101)_2$ 就表示选择第 1 ,3 个向量,不选第 2 个。要枚举所有的状态,我们只需要把一个数从 $0$ 一直累加到 $2^n$ 就可以了,并且每次累加的时候检查他每一位是 0 还是 1 。当然,因为我们这里采用的是折半搜索,所以只需要累加到 $2^\frac{n}{2}$。 至于复杂度的话,可能比 dfs 还慢?而且码量更大?毕竟每次累加还要写一个循环把这个数字从第一位检查到第二十位。不过,因为不需要递归,所以不需要一直给递归的函数开栈,内存占用可能会少一些。 真在比赛的时候还是不建议这样写的,毕竟 dfs 写起来是真的方便,这里只是提供一种好玩的做法。 # 3. 完整代码 ## 3.1 折半搜索 + 哈希表 这个做法一直在调,现在也没卡过时间和空间限制,应该是我哈希函数写挂了,但是我也不知道正确的写法是怎么样的,如果想用这个做法,可以看其它大佬的题解。如果你会这个方法,而且能帮我调一下的话我表示非常感谢,这是[提交记录](https://www.luogu.com.cn/record/71362564),代码我放在[这里](https://www.luogu.com.cn/paste/081u2pfp)了。 ## 3.2 折半搜索 + 双指针a + 二进制枚举方法 ```cpp #include<bits/stdc++.h> using namespace std; #define ll long long #define rg register const int MAXN = 45; struct Instruct{ ll x, y; int k; const bool operator < (Instruct b) const{ if(x != b.x) return x < b.x; if(y != b.y) return y < b.y; return k < b.k; } const bool operator == (Instruct b) const{ return x == b.x && y == b.y; } }ins[MAXN]; vector<Instruct> fir_half, sec_half; ll ans[MAXN]; int mx_state; int n; int tar_x, tar_y; void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){ //根据当前提供的状态 cur_state, 把选中的向量累加起来 //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。 ll tot_x = 0, tot_y = 0; int k = 0; int len = ed - st + 1; for(int i = 1; i <= len; i++){ if(cur_state & (1 << (i - 1))){ tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y; k++; } } (*cur_half).push_back({tot_x, tot_y, k}); } void state_generator(bool mode){ //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分 rg int cur_state = 0;//初始的状态,对应的就是什么都不选 int st, ed; vector<Instruct> *cur_half;//fir_half和sec_half储存这前半部分的方案和后半部分的方案 //cur_half就是当前这次搜索要把方案存到哪里 if(mode){ cur_half = &sec_half; st = n / 2 + 1, ed = n; if(n & 1) mx_state = mx_state * 2 + 1; //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数 //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1 } else{ cur_half = &fir_half; st = 1, ed = n / 2; } while(cur_state <= mx_state){ vec_sum(st, ed, cur_state, cur_half); cur_state++; } } int main(){ scanf("%d%d%d", &n, &tar_x, &tar_y); for(int i = 1; i <= n; i++){ scanf("%lld%lld",&ins[i].x, &ins[i].y); } for(int i = 0; i < n / 2; i++){ mx_state = mx_state | (1 << i); //最大的状态是 n/2 位都是 1 } state_generator(0); state_generator(1); sort(fir_half.begin(), fir_half.end()); sort(sec_half.begin(), sec_half.end()); rg int p1 = 0, p2 = sec_half.size() - 1; int fir_same_k[21], sec_same_k[21]; while(p1 < fir_half.size() && p2 >= 0){ Instruct &f = fir_half[p1], &s = sec_half[p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ //如果两个向量相加小于目标值,我们只能加 p1 的值, //因为 p2 指向的元素最开始就是最大的。 p1++; } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ //如果两个向量相加大于目标值,我们只能减 p2 的值, //因为 p1 指向的元素最开始就是最小的。 p2--; } else{ int p1t, p2t; memset(fir_same_k, 0, sizeof(fir_same_k)); memset(sec_same_k, 0, sizeof(sec_same_k)); //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组 for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){ //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1 fir_same_k[fir_half[p1t].k]++; } for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){ //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2 sec_same_k[sec_half[p2t].k]++; } //统计答案,对于前半段和后半段都枚举可能的 for(int i = 0; i <= 20; i++){ for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的 ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j]; //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同 } } p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环 } } for(int i = 1; i <= n; i++){ printf("%lld\n", ans[i]); } system("pause"); } ``` ## 3.2 折半搜索 + 双指针b + 二进制枚举方法 ```cpp #include<bits/stdc++.h> using namespace std; #define ll long long #define rg register const int MAXN = 45; struct Instruct{ ll x, y; const bool operator < (Instruct b) const{ if(x != b.x) return x < b.x; return y < b.y; } const bool operator == (Instruct b) const{ return x == b.x && y == b.y; } }ins[MAXN]; vector<Instruct> fir_half[MAXN], sec_half[MAXN]; ll ans[MAXN]; int mx_state; int n; int tar_x, tar_y; void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){ //根据当前提供的状态 cur_state, 把选中的向量累加起来 //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。 ll tot_x = 0, tot_y = 0; int k = 0; int len = ed - st + 1; for(int i = 1; i <= len; i++){ if(cur_state & (1 << (i - 1))){ tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y; k++; } } cur_half[k].push_back({tot_x, tot_y}); } void state_generator(bool mode){ //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分 rg int cur_state = 0;//初始的状态,对应的就是什么都不选 int st, ed; vector<Instruct> *cur_half; if(mode){ cur_half = sec_half; st = n / 2 + 1, ed = n; if(n & 1) mx_state = mx_state * 2 + 1; //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数 //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1 } else{ cur_half = fir_half; st = 1, ed = n / 2; } while(cur_state <= mx_state){ vec_sum(st, ed, cur_state, cur_half); cur_state++; } } int main(){ scanf("%d%d%d", &n, &tar_x, &tar_y); for(int i = 1; i <= n; i++){ scanf("%lld%lld",&ins[i].x, &ins[i].y); } for(int i = 0; i < n / 2; i++){ mx_state = mx_state | (1 << i); //最大的状态是 n/2 位都是 1 } state_generator(0); state_generator(1); for(int i = 0; i <= n / 2 + 1; i ++){ sort(fir_half[i].begin(), fir_half[i].end()); sort(sec_half[i].begin(), sec_half[i].end()); } for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){ //枚举当前前半部分的 k for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k int p1 = 0, p2 = sec_half[sec_k].size() - 1; while(p1 < fir_half[fir_k].size() && p2 >= 0){ Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ p1++; //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1 } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ p2--; //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2 } else{ int p1t = p1, p2t = p2; while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){ p1t++; } while(p2t >= 0 && sec_half[sec_k][p2t] == s){ p2t--; } ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); //把 p1 范围长度乘上 p2 范围的长度 //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while // 在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果 // 它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。 p1 = p1t, p2 = p2t; } } } } for(int i = 1; i <= n; i++){ printf("%lld\n", ans[i]); } system("pause"); } ``` 最后,希望这篇题解对你有帮助。有任何问题都可以在私信和评论区提出,我会尽量解决问题。