怎么在麻将里开科技
AzusaShirasu · · 科技·工程
前言
某群里提到希望有一篇 AI 的闲聊。
科学麻将死路一条,但是科技麻将呢?
那好,就用 AI 来开科技。
现在的日麻 AI 水平真的不差(见前文,哪怕是单纯全牌效无防守做牌都能到不错的段位)。更何况这是会学习的 AI,几十万甚至几百万的高质量对局数据全部看完,就算是普通人也会不少了。
假如你已经弄到了一个可以用的打牌 AI,如果要用 AI,首先得知道手上有什么牌。本文将介绍一种优秀的算法,训练一个识别麻将牌的模型,不论是训练过程还是使用过程都非常快速。
为了简化,先只识别类似下图的,某二字在线麻将平台的游戏截图中玩家自己的手牌(底部 14 张)。
虽然这个不是挂,并且个人感觉目前技术水平不太有人能凭一己之力做出可用的挂,但还是声明:本文只是用作技术分享,真的因为什么事情被封了与作者无关。
另外,本文不包括任何帮助玩家决策的 AI,没有挂。
准备
YOLO,全称 You Only Look Once,是一种目标实时检测算法,单次前向传播就可以在图像中识别并定位多个对象——直接在图像中预测对象的边界框和类别概率,比 CNN 系列快。
为什么不采用 OpenCV?因为太慢而且功能太少,准确性太低。OpenCV 用的一般是滑动窗口检测,计算复杂度很高。YOLO 采用的是深度学习算法,可以做到工业级的每秒几十帧的检测速度。
简单地说:这是一个能检测图像内包含的东西的算法,速度很快还准确。
还有个好处:在 Python 上,如果只是想简单地训练一个 YOLO 模型然后使用,那么它已经有包装好的版本,不需要太多底层技术代码,只需要编写少量代码就能够开始运行。
YOLO 有多个实现版本,这里推荐用 Ultralytics。它依赖 PyTorch(因此某云计算平台不能用)。可以用 pip 安装:
pip install ultralytics
要 GPU 加速的话得额外安装 CUDA:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
不论如何,装完一定要先检查 CUDA 能不能用。在 Python 内运行 print(__import__('torch').cuda.is_available()),看到 True 就行了。
数据集
数据集的文件目录如下:
dataset/
├── images/
│ ├── train/
│ └── val/
└── labels/
├── train/
└── val/
原始图像
把所有要训练的图片放在 images/train 下即可。images/val 文件夹下面也放一些图片,作为验证集。
训练集用于训练出数据,由于训练是多个轮次(epoch)的,验证集用于在每个轮次结束时检测模型的预测效果准不准。
形象地说,训练集是课本,验证集是考题。程序就是学生,会经过多轮复习,每轮复习末尾都会进行一次模拟考试(并且考完立刻忘掉所有考题)。
虽然 YOLO 似乎对这个问题不是很敏感,但是也别偷懒,别从训练集里面直接复制一部分粘贴到验证集里(即:数据泄露)。还是用上面的比方,虽然考生记不住考试的题目,但是如果平时作业里有原题,那么学生在经过很多轮的复习之后,可能只是记住这些原题的答案而不是真正掌握了方法(即:过拟合),这时候验证集的验证效果就失效了。
描述文件
为了方便 YOLO 找到数据,要准备一个 yaml 文件描述数据集。如下:
# 数据集目录
path: ./dataset
# 训练集图片的目录
train: images/train
# 验证集图片目录
val: images/val
# 需要识别的各类对象,可以自行添加更多的对象
names:
0: 甲对象名
1: 乙对象名
2: 丙对象名
对于常规的立直麻将来说,一共有 37 种牌(万饼索三种花色各九张,字牌七张,另外三种赤宝牌)。所以要分 37 类对象。
数据标注
YOLO 的功能是:给定一张图片,把图片内的对象的位置框选出来。所以训练用的数据里还需要包括这些信息。
标注文件就是用来描述一张图里有哪些对象、每个对象在哪个位置、占据多大一片空间的文件。YOLO 采用的标注格式很简单,格式(用信息竞赛选手习惯的方式)如下:
- 若干行,一行描述一个对象。先后顺序无关。
- 第
i 行,用空格分隔的五个数a_i,x_i,y_i,w_i,h_i ,其中a_i 为整数,其他为实数,表示第i 个对象的类别编号(为整数)为a_i ,中心点位于(x_i, y_i) ,在图像中所占据的宽度和高度分别为w_i,h_i 。 - 坐标的定义:图像的左上角为
(0,0) ,右下角为(1,1) 。保证x_i,y_i,w_i,h_i \in [0, 1] 。
例如如下描述文件:
2 0.50 0.50 0.10 0.12
0 0.70 0.80 0.15 0.20
表示图内有两个对象,其中一个对象是编号 2 类对象,位于图像正中间,宽度为图片的
标注文件
所有标注文件需要根据对应图片的位置放在 labels/train 或者 labels/val 目录下。文件名要相同,拓展名为 .txt。例如,在 images/train 下的图片 test1.png,要在 images/train 下有一个对应文件 test1.txt。
数据收集
接下来就可以开始收集训练数据了。收集游戏截图,把不需要的部分遮挡起来,再给每张图做标注,就可以了。
然后就会发现卡在这一步。
收集足够多的游戏截图,并且对其中的牌做出标注并不是容易的事情。如果手工做的话会非常消耗时间。
这里就可以用自动化的办法去生成数据了。
现实中的麻将会有光照角度甚至磨损等各种影响训练的原因,网络麻将就好很多,不仅材质固定(不换皮肤的话),甚至连牌摆放的位置都是固定好的。既然图像非常有规律,那为何不自己生成图像呢?
先收集所有 37 种牌的材质:
然后,找一张模板图,把无关的牌全部遮挡起来,比如(为了保护隐私把玩家 ID 也遮住了,实际训练的时候不用):
写一个 Python 脚本,把牌的材质粘贴到图上手牌的位置,就能生成一张图了,然后再顺便生成描述文件。这相当于写数据生成器。流程如下:
- 读取模板图和所有材质;
- 提前设定第 1 张手牌的起始坐标,然后设定间隔来确定之后的牌的位置。第 14 张牌间隔不太一样,直接设定坐标即可;
- 把材质放缩到合适的大小,粘贴;
- 根据粘贴的是哪种牌以及粘贴的坐标,往标注文件内写入数据。
习惯用 Pillow 库写(这种很机械代码找 AI 帮写更省事)。关键代码如下:
template = Image.open("template.jpg").convert("RGBA")
def generate_image(tiles, res_img, res_label):
selected_tiles = random.choices(TILES, k=14)
label_data = []
for i in range(13):
tile_type = selected_tiles[i]
if tile_type not in tiles:
continue
x_pos = FIRST_TILE_POS[0] + int(i * TILE_X_SPACING)
y_pos = FIRST_TILE_POS[1]
template.paste(tiles[tile_type], (x_pos, y_pos), tiles[tile_type])
img_width, img_height = template.size
x_center = (x_pos + TILE_WIDTH / 2) / img_width
y_center = (y_pos + TILE_HEIGHT / 2) / img_height
width = TILE_WIDTH / img_width
height = TILE_HEIGHT / img_height
class_id = TILE_TYPES.index(tile_type)
label_data.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
tile_type = selected_tiles[13]
if tile_type in tiles:
x_pos, y_pos = LAST_TILE_POS
template.paste(tiles[tile_type], (x_pos, y_pos), tiles[tile_type])
img_width, img_height = template.size
x_center = (x_pos + TILE_WIDTH / 2) / img_width
y_center = (y_pos + TILE_HEIGHT / 2) / img_height
width = TILE_WIDTH / img_width
height = TILE_HEIGHT / img_height
class_id = TILE_TYPES.index(tile_type)
label_data.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
template.convert("RGB").save(res_img)
with open(res_label, 'w') as f:
f.write("\n".join(label_data))
生成的数量自行把握,个人推荐训练集 600 张,验证集 60 张,然后就可以在文件夹里看到生成的图片了。不需要太像,YOLO 的泛化能力并不差。
训练
准备完数据集,就可以开始训练了。前面说过,只是简单训练一个模型的话,代码真的很短:
from ultralytics import YOLO
import torch.multiprocessing as mp
if __name__ == '__main__':
mp.freeze_support()
model = YOLO('yolov8m.pt')
model.train(
data='dataset/data.yaml',
epochs=100,
batch=12,
name='mjrivM'
)
- 第 4 行:不加就报错,加了就能运行。那就加上。
- 第 6 行:初始模型。YOLO 经过了多个版本的迭代,有 v3、v5 到 v11 等多个版本。这里选择 v8。后缀字母
m表示这是一个中等模型。后缀有如下几种:
| 后缀字母和全称 | 模型速度 | 规模与准确率 |
|---|---|---|
| n,Nano | 很快 | 非常小,但是精度低 |
| s,Small | 较快 | 轻量模型,精度较低 |
| m,Medium | 中等 | 平衡版 |
| l,Large | 较慢 | 模型较大,更精细 |
| x,Extra | 很慢 | 大型模型,极精细 |
- 第 8 行:数据描述文件。
- 第 9 行:
epoch参数即为训练的轮数,这次的模型比较小,100 轮足够了,如果需求更复杂的话可能要至少 200 轮。不是越多越好,容易过拟合。 - 第 10 行:
batch参数即为批次,根据 GPU 的吞吐量决定,不是越大越好。 - 第 11 行:
name参数就是这个模型的名字,取什么都没有关系,只是方便寻找训练出来的模型文件而已。
第一次运行会比较久,可能还有看似卡住的情况,其实没卡,程序只是在从服务器那里下载初始模型 yolov8m.pt 文件,之后就会快了;此外,程序还会给训练的数据集生成 cache 文件用来加速访问,这一步也会花时间。
看清楚程序输出的是 Warning 还是 Error / Exception。能运行就不要去管它。看到类似下面的输出就说明正式开始训练了:
训练的过程中,可以到 runs/detect/ 文件夹下,里面是每次训练的模型数据,后面跟着编号。点进去,有个 weights 文件夹,里面有 best.pt 和 last.pt 两个文件,best 就是目前多轮次训练中遇到的表现最好的文件,last 就是最后一次的文件。不需要等待训练完成,训练过程中就可以把这两个文件复制出来自行测试看看效果。
优化模型
对于这种小规模的数据不需要等很久,拿块 5080 跑个一小时不到就训练完了。训练完的模型效果并不一定理想,测试一下。准备一张测试用的图片,Python 加载模型测试,代码很简单:
from ultralytics import YOLO
model = YOLO('best.pt')
results = model.predict('test.png')
results[0].show()
for result in results:
boxes = result.boxes
for box in boxes:
class_id = box.cls.item()
class_name = model.names[class_id]
confidence = box.conf.item()
print(f"Find {class_name}, confidence={confidence:.2f})")
第 5 行会把模型预测结果可视化地展现出来,如下:
在这次的训练中,初版模型效果并不好。一般先排查数据集的原因:是不是随机性不够,是不是哪种类型的对象样本太少。如果是更复杂的任务,比如识别现实世界中的麻将,那可能还有数据泛化性不够的原因。
排除完数据集的缺陷之后,就要看看模型效果差在哪里,一般有如下几种类型:
- 误检:看错对象的类型;
- 漏检:没有识别出对象;
- 错检:正确识别出类型,但是位置标注不正确。
很明显这次初版模型存在的问题是漏检。怎么办呢?强化训练,发现这个模型容易一次性看漏两三张相似的牌,既然如此,那就构造新的一批数据,在手牌里把相似的牌堆砌在一起。生成的数据大致长这样(不用管合理性):
这一次的数据集可以稍微小一些,再次训练:
from ultralytics import YOLO
import torch.multiprocessing as mp
if __name__ == '__main__':
mp.freeze_support()
model = YOLO('yolov8m.pt')
model.train(
data='dataset/data.yaml',
epochs=100,
batch=12,
freeze=10,
name='mjrivM'
)
第 11 行多出的 freeze 参数是在冻结部分骨干层的参数,这样可以加速训练而且不容易过拟合。事实上,真实的训练里还需要调节其它参数。
再次开始训练,训练出的结果就不错了。
这样,用来识别手牌内包含的牌差不多已经足够了。
未来改进
实战可能性
识别手牌是最简单的,也已经消耗了很多功夫。在实战中,有各种类型的牌:牌河、宝牌指示牌、副露、三麻的拔北……这些都需要很大的演算力。单纯把所有牌都看出来还不够,怎么识别暗杠是个问题。
不过这不意味着本次训练出的模型没有意义,至少在何切建议问题上可以用来“拍照搜题”(市面上已经有类似的拍照给建议工具,但是还收费?不如自己做)。而且,至少打通了训练识别麻将牌的 AI 这条路,后面的只是技术的问题。
泛化
如果要用于更多用途的话,首要的还是泛化数据集。本次训练里所有的数据都是非常模式化的,例如,用这次训练出的小模型识别现实中的麻将,效果很差(无法识别倒过来的牌):
而泛化后的模型就好很多了:
当然,如果能训练到看穿牌山和别人手牌……
(珍惜手指,请勿出千)