基础的搜索优化技巧
这里主要记录一些搜索优化。
众所周知,dfs 是很有用的算法。可以骗很多题的分。\ 但是它的时间复杂度太高,有些时候分很少。\ 于是我们可以加入一些优化,让它获得更高的分数,甚至 AC。
比如剪枝。
剪枝
最优化剪枝
dfs 的过程中,如果发现搜索树的当前分支已经不能产生更优的答案,那么就直接 return,不再继续搜索这一分支。
可行性剪枝
如果当前的方案已经不可行,直接 return 而不继续进行无用的搜索。
重复性剪枝
排除一些已经搜索过的分支以提高效率。
注意:以上的几种剪枝也可以推广到 bfs。具体的做法是如果当前节点不满足题目要求/不可能更新最优答案/已经搜索过就不再扩展当前节点,直接将它抛弃。
卡时
当程序快要超时时直接输出当前已经搜索到的最优解。
很多时候你会遇到有人建议你使用 clock 函数进行卡时,但这并不是一个好的方法。\ clock 函数返回的是进程运行的 CPU 时间,而不是程序实际计算消耗的时间。在本地调试时输入时的等待也会被 clock 计入 CPU 时间中,这会导致你还没来得及输入程序就终止了。在 OJ 上评测时也容易产生各种问题。
更好的方法是使用一个变量cnt
记录已经处理过的情况的总数,在cnt
超出某个限度(即将超时)时结束程序。
更改搜索顺序
常常和卡时一起使用。让最优解更早出现以防止卡时后无法得到最优解,是一个比较玄学的优化。
正经讲解:
我刚刚已经讲了一些简单但有用的优化。我知道你没有听懂。大佬请忽略\
现在让我们用一些例题具体讲讲这些优化。
【模板】01背包 | 采药
容易想到这一题的搜索解法。也容易发现搜索只能得三十分。
现在,让我们用之前讲到的方法来优化我们的搜索。
可以加入可行性剪枝与最优化剪枝。如果当前状态的总重已经超过背包容量,直接 return。如果之后所有物品的价值总和加上当前已选物品的价值总和不超过已有的最佳答案,也 return。
两个剪枝正确性显然。
第二个剪枝可以用前缀和优化求和。
可以先对物品按性价比从大到小排序,能起到一定的优化作用。
于是有了这份代码:
#include <bits/stdc++.h>
using namespace std;
int t, m, sum[105], ans;
struct things {
int w, v;
} a[105];
void dfs(int i, int w, int v) {
if (i == m + 1) {
ans = w;
return;
} else {
if (v + a[i].v <= t) {
dfs(i + 1, w + a[i].w, v + a[i].v);
}
if (w + sum[m] - sum[i] > ans) {
dfs(i + 1, w, v);
}
}
}
int main() {
cin >> t >> m;
for (int i = 1; i <= m; ++i) {
cin >> a[i].v >> a[i].w;
}
sort(a + 1, a + m + 1, [](things a, things b) -> bool {
return a.w * b.v > b.w * a.v;
});
for (int i = 1; i <= m; i++) {
sum[i] = sum[i - 1] + a[i].w;
}
dfs(1, 0, 0);
cout << ans;
return 0;
}
这份代码的优化有一些问题,直接提交会 t 掉。请大家自行思考它错误的原因。
再有一题:
P1004 [NOIP2000 提高组] 方格取数
这题的正解是动态规划。但是也有用搜索的解法。\
考虑同时搜索两条路径(与 dp 解法类似)。于是可以得到一个
朴素的搜索代码:
//creat by Code Copilot
#include <bits/stdc++.h>
using namespace std;
int n, a[10][10], ans;
void dfs(int x1, int y1, int x2, int y2, int val) {
if (x1 > n || y1 > n || x2 > n || y2 > n) return; // 越界检查
int next_val = val;
if (x1 == x2 && y1 == y2) { // 两条路径在同一格子
next_val += a[x1][y1];
} else {
next_val += a[x1][y1] + a[x2][y2];
}
if (x1 == n && y1 == n && x2 == n && y2 == n) { // 终止条件:两条路径都到达(n, n)
ans = max(ans, next_val); // 更新最大值
return;
}
// 四种路径走法
dfs(x1 + 1, y1, x2 + 1, y2, next_val);
dfs(x1 + 1, y1, x2, y2 + 1, next_val);
dfs(x1, y1 + 1, x2 + 1, y2, next_val);
dfs(x1, y1 + 1, x2, y2 + 1, next_val);
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> n;
int x, y, z;
while (cin >> x >> y >> z, x || y || z) {
a[x][y] = z;
}
dfs(1, 1, 1, 1, 0); // 从 (1, 1) 开始两条路径
cout << ans;
return 0;
}
考虑加入剪枝。
有一种比较简单的剪枝:如果当前所在点的右下角的矩形的数字之和加上已经取到的数字之和都无法超过当前最优解,就不用继续搜索了。
这个剪枝并不精确,但是效用依然强大。\ 给出一份代码:
//mycode
#include <bits/stdc++.h>
using namespace std;
int n, a[35][35], ans, maxsum[35][35];
void dfs(int x1, int y1, int x2, int y2, int lastsum) {
if (x1 > n || y1 > n || x2 > n || y2 > n) {
return;
} else {
if (lastsum + maxsum[x1][y1] + maxsum[x2][y2] <= ans) {
return;
}
int nowsum = lastsum;
if (x1 == x2 && y1 == y2) {
nowsum += a[x1][y1];
} else {
nowsum += a[x1][y1] + a[x2][y2];
}
if (x1 == n && x2 == n && y1 == n && y2 == n) {
ans = max(ans, nowsum);
return;
} else {
dfs(x1 + 1, y1, x2, y2 + 1, nowsum);
dfs(x1, y1 + 1, x2, y2 + 1, nowsum);
dfs(x1 + 1, y1, x2 + 1, y2, nowsum);
dfs(x1, y1 + 1, x2 + 1, y2, nowsum);
}
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> n;
int x = 1, y = 1, z = 1;
while (x != 0 && y != 0 && z != 0) {
cin >> x >> y >> z;
a[x][y] = z;
}
for (int i = n; i > 0; --i) {
for (int j = n; j > 0; --j) {
maxsum[i][j] = maxsum[i + 1][j] + maxsum[i][j + 1] + a[i][j] - maxsum[i + 1][j + 1];
}
}
dfs(1, 1, 1, 1, 0);
cout << ans;
return 0;
}
总耗时61ms,最大一个点30ms。
如果需要更精确的剪枝,可以考虑使用右下角矩阵的一半那个三角形代替现在的二维前缀和。也可以加入卡时的代码。
再来一道题:
合并石子堆
说明
有一堆石头质量分别为
输入格式
输入第一行只有一个整数
输出格式
输出只有一行。该行只有一个整数,表示最小的质量差。
样例
输入数据 1
5
5
8
13
27
14
输出数据 1
3
这题的搜索并不难想。但是在
搜索代码:
#include <bits/stdc++.h>
using namespace std;
int n, ans = 1e9;
vector<int> a;
void dfs(int i, int w1, int w2) {
if (i == n + 1) {
ans = min(ans, abs(w1 - w2));
return;
} else {
dfs(i + 1, w1 + a[i], w2);
dfs(i + 1, w1, w2 + a[i]);
}
}
int main() {
cin >> n;
a.resize(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
dfs(1, 0, 0);
cout << ans;
return 0;
}
考虑卡时,在搜索
卡时代码:
#include <bits/stdc++.h>
using namespace std;
int n, ans = 1e9, cnt;
vector<int> a;
void dfs(int i, int w1, int w2) {
cnt++;
if (cnt > 50000000) {
cout << ans;
exit(0);
}
if (i == n + 1) {
ans = min(ans, abs(w1 - w2));
return;
} else {
dfs(i + 1, w1 + a[i], w2);
dfs(i + 1, w1, w2 + a[i]);
}
}
int main() {
cin >> n;
a.resize(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
sort(a.begin() + 1, a.begin() + n + 1, greater<int>());
dfs(1, 0, 0);
cout << ans;
return 0;
}
但是这样会有 WA 的测试点。于是考虑更改搜索顺序。
容易想到可以将石子的质量由大到小排序,然后交替合并,哪边少就往哪边加。\ 这是贪心的思想。但是此时可以用来优化搜索。先考虑向少的那边加石子,再搜索另一个分支。
经过这样的优化之后,我们的卡时代码就能 AC 了。
#include <bits/stdc++.h>
using namespace std;
int n, ans = 1e9, cnt;
vector<int> a;
void dfs(int i, int w1, int w2) {
cnt++;
if (cnt > 50000000) {
cout << ans;
exit(0);
}
if (i == n + 1) {
ans = min(ans, abs(w1 - w2));
return;
} else {
if (i % 2) {//这里的搜索顺序和之前说的不一样,不是先考虑向石子少的那边加石子,而是交替添加石子
dfs(i + 1, w1 + a[i], w2);
dfs(i + 1, w1, w2 + a[i]);
} else {
dfs(i + 1, w1, w2 + a[i]);
dfs(i + 1, w1 + a[i], w2);
}
}
}
int main() {
cin >> n;
a.resize(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
sort(a.begin() + 1, a.begin() + n + 1, greater<int>());
dfs(1, 0, 0);
cout << ans;
return 0;
}
迭代加深搜索
有时我们需要找到一个最优解。
确定最优解是多少(如背包的总价值,最短路中的路径长短等)可以帮助我们 dfs。\ 但在没有找到最优解之前一般无法确定最优解的值。\ 所以使用迭代加深搜索(iddfs),从小到大枚举每一种可能的代价,假设它就是最优解的代价进行 dfs。\ 第一个被找到的可行解就是最优解。
iddfs 的优点是可以避免 bfs 庞大的空间消耗,也不会“陷”进某些没有答案的分支里。
在数据较大时,iddfs 的时间复杂度和运行时间与 bfs 相近。
P1874 快速求和
这是一道 dp 题,但是数据范围较小,可以使用 dfs 通过。
使用 iddfs,从 1 到 s.size() - 1
枚举加号的个数,然后进行搜索。
朴素搜索代码(90pts):
#include<bits/stdc++.h>
using namespace std;
basic_string<char> s;
int n;
void dfs(int i, int cnt, int tot, int sum) {//当前的位置,现在加号的数量,加号的总数,已有的和
if (i < s.size()) {
for (int j = i; j <= min((int) s.size() - 1, i + 5); ++j) {
dfs(j + 1, cnt - 1, tot, sum + stoi(s.substr(i, j - i + 1)));//用两个 STL 内置函数计算当前加号到下一个加号间的数字并将它添加到和中
}
} else if (cnt == 0 && sum == n) {
cout << tot;
exit(0);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> s >> n;
for (int i = 0; i < s.size(); ++i) {//最少没有加号,最多每个数字间都有加号
dfs(0, i + 1, i, 0);//因为在第一个数字之前加入了一个加号,所以 cnt 应该加一。
}
cout << -1;
return 0;
}
但是朴素的搜索并不能通过本题,所以使用以下几种优化:
- 卡时
- 使用一个变量记录当前加号到下一个加号间的数字,避免频繁使用 substr 和 stoi。这样可以使时间优化到原来的 1/4 左右。
- 针对 hack#3 进行的优化。此优化针对前导零进行了剪枝。
放一个代码。
#include<bits/stdc++.h>
using namespace std;
basic_string<char> s;
int n;
void dfs(int i, int cnt, int tot, int sum) {
static int t = 0;
t++;
if (t > 100000000) {//卡时
cout << -1;
exit(0);
} else if (sum <= n) {
if (i < s.size()) {
long long tmp = 0;//存储两个加号间的数字
for (int j = i; j < s.size(); ++j) {
tmp *= 10;
tmp += s[j] - '0';
if (tmp > n - sum) {//剪枝
break;
} else if (tmp != 0) {//由于前导零不影响后面的数字的大小,所以可以节省一个加号
dfs(j + 1, cnt - 1, tot, sum + tmp);
}
}
if (tmp == 0) {
dfs(s.size(), cnt - 1, tot, sum);
}//如果从某个加号到末尾之间全是零,也可能是一种合法的方案
} else if (cnt == 0 && sum == n) {
cout << tot;
exit(0);
}
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> s >> n;
for (int i = 0; i < s.size(); ++i) {
dfs(0, i + 1, i, 0);
}
cout << -1;
return 0;
}
于是我们就可以 AC 本题了。
可以试试这道题:
组合总和 II (Combination Sum II)
问题描述:
给定一个候选数字集合 candidates
和一个目标数 target
,找出候选数字集合中所有和为目标数的组合。
每个数字在每个组合中只能使用一次。
优化技巧:
- 剪枝:在搜索过程中,如果当前候选数字之和已经超过目标值,立即返回。
- 去重:在搜索过程中避免产生重复的组合。
记忆化搜索
在讲解了以上的几种优化之后,还有一种非常重要的优化——记忆化搜索。
记忆化搜索本质上是用递归实现的 dp。由于使用递归实现,有时会比递推的 dp 更好想。
记忆化搜索的实现步骤:
- 先写出爆搜程序。
- 将爆搜程序改为“不需要外部变量”“无需副作用”的数学式函数(只要给定 dfs 头部的几个参数,dfs 的返回值就是确定的)。这里 dfs 的返回值就是我们实际要求的答案,如最优解个数、费用最小值等。
- 添加记忆化数组记录 dfs 的返回值。
- 在 dfs 中添加判断条件,如果当前状态已经搜索过那么就返回记录的答案。否则搜索当前状态并记录答案。
如这样的一道例题:
小木棍
乔治有一根小木棍,他把这根木棍随意砍成更小的木棍,直到每段的长都不超过一定值,在经过这一流程后,他得到了
但小凯过来了,他本来是想拿这根小木棍去玩游戏的,现在小凯很生气,因此小凯向乔治提出了一个问题。
小凯把这些短木棍按照一定顺序排列后,他需要乔治把这些短木棍按照这个顺序来划分成若干个非空连续子序列。并且还需要满足如下要求: 若一共分成了
小凯想让乔治回答有多少种划分方式能够满足如上需求。乔治很苦恼,因此他需要你的帮助来解决这个问题。
输入格式
第一行一个正整数
输出格式
输出划分方案数对
数据范围
对于
这题的爆搜并不难写。我们可以设dfs(i,j,sum)
代表前j
个物品共分了i
段,当前段内的总和是sum
时的方案数,于是得到如下的程序:
#include <bits/stdc++.h>
using namespace std;
#define mod 998244353
int n, a[2005];
unsigned long long ans;
void dfs(int i, int j, unsigned long long sum) {
sum += a[j];
if (j == n) {
if (sum % i == 0) {
ans++;
}
if (ans > mod) {
ans -= mod;
}
} else {
dfs(i, j + 1, sum);
if (sum % i == 0) {
dfs(i + 1, j + 1, 0);
}
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
dfs(1, 1, 0);
cout << ans;
return 0;
}
但是最小一档的n
都可能取到
考虑记忆化搜索。
我们已经得到了爆搜的程序,现在将dfs
转化成数学式的函数。
显然的,当dfs
函数达到边界时,如果当前块的总和可以被当前块的编号整除,那么符合题意的方案数就是
然后添加记忆化数组。
这里是不能直接用数组的,空间会炸。但是对于每一个i
与j
,能用到的sum
的值都是很少的,于是可以使用 map 优化空间。
于是可以得到这样一份40分的代码。
#include <bits/stdc++.h>
using namespace std;
#define mod 998244353
int n, a[2005];
map<unsigned long long, unsigned long long> opt[2005][2005];
unsigned long long dfs(int i, int j, unsigned long long sum) {
sum += a[j];
if (j == n) {
if (sum % i == 0) {
return 1;
} else {
return 0;
}
} else {
auto item = opt[i][j].find(sum);
if (item != opt[i][j].end()) {//如果已经记录有当前状态的答案
return item->second;//那么就直接返回已有的答案
} else {
opt[i][j].emplace(sum, dfs(i, j + 1, sum));//否则新开一个节点存储答案
if (sum % i == 0) {//如果另一个分支合法
opt[i][j][sum] += dfs(i + 1, j + 1, 0);//那么加上它的答案
}
return opt[i][j][sum] %= mod;//记得取模
}
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
cout << dfs(1, 1, 0);
return 0;
}
时间复杂度
启发式搜索
在搜索的过程中,我们会对搜索树的各分支作出估价,先搜索更优的分支,再搜索其他分支,无用的分支直接剪去不再搜索。
这就是启发式搜索。读者可能会注意到我们之前所讲的几种优化也可以归入它的范畴之内。所以它们也是启发式搜索。
但是之前几种优化的启发式方案都是写程序时预先定好的,不够灵活。
下面我们将介绍几种启发式搜索。
A*
A* 是 dijkstra 的一种优化。
普通的 dijkstra 将待扩展的节点按照已经走过的路程排序,每次取出路程最小的节点。
而 A* 则添加了一个估价函数用于估测当前状态到目标的距离。对节点的排序方案不再是已经走过的距离,而是已经走过的距离与到终点的估价的总和。
这能优化 A* 的性能。比较显然的,估价函数的估价与到终点的真实距离越接近,A* 的优化就越明显。
不太显然的,估价函数的值必须小于等于当前节点到终点的最小距离,否则 A* 无法得到最优解。\ 我并不清楚如何证明这个结论,于是做一个感性理解:\ 如果我们现在要用 A* 找起点与终点之间的最短路径。有两条路径,一长一短。\ 尝试 hack A*。最坏的情况是较短的那条路径上的估价每次都刚好是剩余的路程,而较长路径的估价都为零。这样在短的路径上走的时候每个点走过的路程与估价的和都是路径的总长,而较长路径上的节点走过的路程与估价的和都是已经走过的路程。\ 但是随着搜索的不断进行,长的那条路径上的点走过的路程总会超过短路径的总长。这时在优先队列中短路径上的点就会被扩展,A* 就能得到正确的结果。
A* 的一道经典题目是八数码。
八数码
这道题可以使用 bfs 通过。但是时间比较吃紧。考虑使用 A*。
可以设置估价函数为当前状态与目标状态之间错位的数字的个数。\ 于是可以得到这样的一份代码:
#include <bits/stdc++.h>
using namespace std;
string start, e = "123804765";//e是目标状态
int ans;
int h(string x) {//估价函数
int res = 0;
for (int i = 0; i < 9; ++i) {//统计错位数
if (x[i] != e[i]) {
res++;
}
}
return res;
}
struct node {
string x;
int dis;
node(string s, int d) : x(std::move(s)), dis(d) {}
};
struct cmp {//优先队列的比较类
bool operator()(node a, node b) {//按已经走过的路程与估价的和排序
return (a.dis + h(a.x)) > (b.dis + h(b.x));//从小到大
}
};
priority_queue<node, vector<node>, cmp> q;
set<string> s;
void A_star() {
q.emplace(start, 0);
s.insert(start);
while (!q.empty()) {
node qaq = q.top();
q.pop();
if (qaq.x == e) {//当找到答案时直接输出
ans = qaq.dis;
cout << ans;
exit(0);
} else {
int i = qaq.x.find('0');//找到空格的位置
string tmp = qaq.x;
if (i > 2) {//枚举每一种可能的变换
swap(tmp[i], tmp[i - 3]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i < 6) {
tmp = qaq.x;
swap(tmp[i], tmp[i + 3]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i % 3) {
tmp = qaq.x;
swap(tmp[i], tmp[i - 1]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i % 3 < 2) {
tmp = qaq.x;
swap(tmp[i], tmp[i + 1]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
}
}
}
int main() {
cin >> start;
A_star();
return 0;
}
这就是一种 A* 的代码。
还可以使用另一种估价函数:将每个数字与目标状态中它的位置的曼哈顿距离之和作为估值。\ 因为估价更接近实际,这样的程序更快。
#include <bits/stdc++.h>
using namespace std;
string start, e = "123804765";//e是目标状态
int ans;
constexpr int pos[] = {4, 0, 1, 2, 5, 8, 7, 6, 3};
int h(string x) {//估价函数
int res = 0;
for (int i = 0; i < 9; ++i) {//对于每一个数,将它与目标位置的曼哈顿距离累加到返回值中
res += abs(pos[x[i] - '0'] % 3 - i % 3) + abs(pos[x[i] - '0'] / 3 - i / 3);
}
return res;
}
struct node {
string x;
int dis;
node(string s, int d) : x(std::move(s)), dis(d) {}
};
struct cmp {//优先队列的比较类
bool operator()(node a, node b) {//按已经走过的路程与估价的和排序
return (a.dis + h(a.x)) > (b.dis + h(b.x));//从小到大
}
};
priority_queue<node, vector<node>, cmp> q;
set<string> s;
void A_star() {
q.emplace(start, 0);
s.insert(start);
while (!q.empty()) {
node qaq = q.top();
q.pop();
if (qaq.x == e) {//当找到答案时直接输出
ans = qaq.dis;
cout << ans;
exit(0);
} else {
int i = qaq.x.find('0');
string tmp = qaq.x;
if (i > 2) {//枚举每一种可能的变换
swap(tmp[i], tmp[i - 3]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i < 6) {
tmp = qaq.x;
swap(tmp[i], tmp[i + 3]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i % 3) {
tmp = qaq.x;
swap(tmp[i], tmp[i - 1]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
if (i % 3 < 2) {
tmp = qaq.x;
swap(tmp[i], tmp[i + 1]);
if (!s.count(tmp)) {
q.emplace(tmp, qaq.dis + 1);
s.emplace(tmp);
}
}
}
}
}
int main() {
cin >> start;
A_star();
return 0;
}
后记
受限于本人能力,本文所讲的技巧相对基础,但基础的往往也是有用的。\ 这里没有习题,搜索的优化也不是仅通过几道题就能完全掌握的。但搜索优化是用途极为广泛的技巧。在日常的训练中,它们的题目无处不在。我们应该在所有搜索题目中都思考如何优化,这样才是训练我们的技巧的正确方式。\ 希望本文所介绍的技巧能够让你得到更多的分数,AC 更多的题目。