当你迷茫的时候
AzusaShirasu · · 科技·工程
引子
为什么会迷茫?不是什么都做不了,而是因为能做的事情太多了,想选的太多而无所适从。
唐代诗人岑参的名作《白雪歌送武判官归京》中曾经写过如下名句:
北风卷地白草折,胡天八月即飞雪。\ 忽如一夜春风来,千树万树梨花开。\ 散入珠帘湿罗幕,狐裘不暖锦衾薄。\ 将军角弓不得控,都护铁衣冷难着。\ 瀚海阑干百丈冰,愁云惨淡万里凝。\ 中军置酒饮归客,胡琴琵琶与羌笛。\ 纷纷暮雪下辕门,风掣红旗冻不翻。\ 轮台东门送君去,去时雪满天山路。\ 山回路转不见君,雪上空留马行处。
没错,每一句都是名句,不信随便去抽个高三学生,问他会不会背?请问高考生有多少啊,那么多人会背的句子还不是名句?
刚从 OI 退役下来不会背?如果退役下来仍然需要背这个,而且还因为背不下来而苦恼的话……不想放弃 OI 又不能不读文化课,既要又要,想选又不敢 / 不甘,这就是迷茫。
前言
立直麻将里也经常有迷茫的时候,比如下面一手牌怎么切?
虽然之前已经研究过怎么科学地计算向听数和有效牌了,但是对实战的指导性不强,毕竟单纯知道向听数也没用。nekobean - mahjong-cpp 这个项目给出了用 dp 计算何切的期望得点,但是从实践情况来看,这个项目有个问题是有点慢(计算三向听在我的设备上居然要快四秒),还有可以改进的地方。
最大的可以改进的部分是:这个项目部署之前需要打几张庞大的表才能基于此运行,而且运算比较吃 CPU。而大部分时候,玩家只关心的是策略(打哪张牌)而非细节(期望得点、听率与和率的具体值)。那用什么办法改进呢?
仍然是 AI。
设计一个 AI 模型来完成指定任务,需要的思维也许没有一道信息竞赛题那么繁复,但也相当细致全面。本文大部分篇幅是在讲述怎么设计和优化的,如果只想要结论,可以跳读本文。
数据编码
何切问题是个静态的问题,也就是说只关心当前状态,不关心先前动作,因此训练起来会相对容易一些。把何切问题抽象一下,输出参数有两个:要切掉的牌、向听数(不是直接提供策略支持,但是很有用)。而输入参数有这么几个:
- 牌山剩余量(通过统计牌桌上出现的牌,可以得知牌山大概每种牌还剩多少);
- 手上的手牌,14 张;
- 有一些牌价值更高(宝牌、红宝牌)。
第一步仍然是把牌编码为整数,这里照样使用
牌山剩余量将作为一个
关于手上的手牌,下面有两种编码策略:
- 稀疏索引:编码为一个
14 维向量,每一维的取值为0 到33 ,表示牌的编号; - 稠密向量:编码为一个
34 维向量,每一维的取值为0 到4 ,表示对应牌的数量。
哪一种更好?绝对是后者。首先,牌山剩余量已经编码为了
此外还有一个问题,稀疏索引中,[3, 2, 1] 和 [1, 2, 3] 表示上是不一样的,可本质是相同的;虽然可以排序,但是模型仍然需要额外精力去学习这些值之间的关系——比如相邻连续牌(顺子)、相同牌计数(是对子还是刻子),并不合适。
同理,何切问题的答案(切掉哪张牌)也可以用这样的方式表示。最直观的想法是:一个值,
然后是宝牌的表示。顺着思路,设计一个
这里需要注意:需要让模型意识到自己手上宝牌的价值,以及牌山中宝牌的潜力。把宝牌向量分作两个向量,一个专门表示牌山中的宝牌,一个表示手上的宝牌。
关于向听数,一手牌的向听数最多是
综上所述,将数据按如下方式编码:
- 牌山剩余量:记作
yama,34 维向量,第i 维表示编号为i 的牌剩余数量; - 牌山宝牌剩余量:记作
d_yama,34 维向量,第i 维表示编号为i 的牌的价值,每一番加一(注意是每一番,双宝牌一张能贡献2 ,红宝牌在原基础上额外贡献1 ); - 手牌:记作
tehai,34 维向量,第i 维表示编号为i 的牌在手牌内的数量; - 手牌宝牌:记作
d_tehai,34 维向量,第i 维表示编号为i 的牌(如果在手牌内)提供的价值,也是每一番加一。 - 切牌:记作
kiru,34 维向量,切牌答案对应的那一维是1 ,其他都是0 ; - 向听数:记作
shanten,一个单独的整数,表示目前手牌的向听数。
如上,数据编码的方式是一个不可忽视的课题。好的数据编码能帮助模型更有效地提取特征。设计数据的编码方式有很多学问,最基础的一个思考方式就是如上:先设计几个编码,然后看把两个向量参数
数据收集
这里就涉及到一个数据质量的问题。一开始从天凤下载到了牌谱数据,然后解析,把牌局的真实情况转化为何切。
很快发现数据质量有问题。除去人类的何切判断不一定准确,还有一个本质上的问题:立直麻将是注重防守的,相当多的时候并不是根据何切策略在切牌,而是在兜牌甚至完全弃和。人工挑选也不可能。
还记得之前训练 YOLO 视觉模型时候的解决办法吗?如果现实太麻烦、人工有困难,那就用机器生成!
现成的何切计算器已经有了,就用那个 mahjong-cpp 当做生成的核心。这样保证何切判断不会像人类一样有潜在的质量问题。
那么,何切题面怎么生成呢?随机挑一张当宝牌指示牌、随机生成一副手牌、然后随机把一些牌扔掉(相当于减少牌山剩余量)?
这行不通。之前说过,随机生成一副 14 张手牌的向听数期望是
那,先做一个向听数计算器,生成的时候按需求生成两向听、一向听和听牌的数据?也行不通,具体原理已经懒得探究了,会生成一堆七对子……但是实际做牌的时候一般不会做成这么多七对子,所以,如果能模拟牌局进行就好了。
所以最终的解决方案是:写代码模拟牌局,四个玩家决策方式是:没有防守,全部用何切计算器决策!代码并不是很多,Python 写的话也就是下面这么多。
这里不放出具体代码了,不过给出模拟牌局逻辑:
- 随机生成牌山,取 1 张当做宝牌指示牌,再取四组手牌,每组 13 张牌,是四个玩家的初始手牌;
- 按顺序给每个玩家发牌,每个玩家拿到牌之后,立刻根据场上信息调用一次何切计算器,然后打出一张牌放到牌河里(全体玩家可见);
- 没有吃碰杠之类的鸣牌(何切问题基本都是门清),一旦有玩家听牌,停止模拟,把牌谱记录下来;没有听牌的话就一直打到场上摸够 70 张牌流局。
然后就能有不错的数据了。平均一次模拟会打 12 巡,又有四个玩家,相当于一次模拟牌局可以提供大约 50 组何切题数据。小小跑个一千组,保存成 json 大概长这样。
按照上面的方式,稍微整一整(大概跑了三十小时,mahjong-cpp 在三向听的时候很慢)有二十六万组数据。这里提醒一下:多开线程的话注意安全关闭的问题,不然会有写一半被关掉的行然后让处理数据的程序报错。
训练
这里就采用 ResNet,把四个输入参数拼成一个
class MahjongResNet(nn.Module):
def __init__(self, input_channels=4):
super(MahjongResNet, self).__init__()
self.in_channels = 64
self.conv = nn.Conv1d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm1d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(64, 2)
self.layer2 = self._make_layer(128, 2, stride=2)
self.layer3 = self._make_layer(256, 2, stride=2)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.kiru_head = nn.Sequential(
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 34),
nn.Sigmoid()
)
self.shanten_head = nn.Sequential(
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def _make_layer(self, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or self.in_channels != out_channels:
downsample = nn.Sequential(
nn.Conv1d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm1d(out_channels)
)
layers = []
layers.append(ResidualBlock1D(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels
for _ in range(1, blocks):
layers.append(ResidualBlock1D(out_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = x.view(-1, 4, 34)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
kiru_output = self.kiru_head(x)
shanten_output = self.shanten_head(x)
return kiru_output, shanten_output.squeeze()
残差块用 1D 卷积,这样比较好捕捉牌与牌之间的关联(比如顺子,要求相邻的牌):
class ResidualBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock1D, self).__init__()
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm1d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm1d(out_channels)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
还需要写一个数据加载类。这部分代码比较简单而且机械,就留给 LLM 写出来就好了(没必要手写这些东西折磨自己!因为真正折磨的东西在后面!)
class MahjongResNetDataset(Dataset):
def __init__(self, data, scaler=None, fit_scaler=False):
self.data = data
X = []
y_kiru = []
y_shanten = []
for item in data:
features = np.stack([
item['yama'],
item['d_yama'],
item['tehai'],
item['d_tehai']
], axis=0)
X.append(features)
y_kiru.append(item['kiru'])
y_shanten.append(item['shanten'])
X = np.array(X, dtype=np.float32)
y_kiru = np.array(y_kiru, dtype=np.float32)
y_shanten = np.array(y_shanten, dtype=np.float32)
# 特征标准化 (按特征组标准化)
if fit_scaler:
self.scalers = [StandardScaler() for _ in range(4)]
for i in range(4):
self.scalers[i].fit(X[:, i, :])
else:
self.scalers = scaler
if self.scalers is not None:
for i in range(4):
X[:, i, :] = self.scalers[i].transform(X[:, i, :])
self.X = X
self.y_kiru = y_kiru
self.y_shanten = y_shanten
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return {
'features': self.X[idx],
'kiru': self.y_kiru[idx],
'shanten': self.y_shanten[idx]
}
还有一些训练要用的东西,比如损失函数和优化器。对于何切,是多标签分类问题,用交叉熵 BCE 损失函数;向听数预是回归问题,用均方误差就可以了。优化器用经典的 Adam 配上经典的学习率:
criterion_kiru = nn.BCELoss()
criterion_shanten = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
然后就可以开始训练了(训练用的代码很模板,找 LLM 写,把上面说的这些告诉它就可以自动帮你完成代码了)。
优化
效果不是很理想。模型似乎在乱打。
测了几组就会发现一个共性的问题:模型似乎很喜欢切掉特定位置的几个牌。此外,何切问题其实次优解和最优解有时候相差不大,模型经常在几个解之间犹豫不决。模型需要优化。
超参数调整
网络层数,学习率之类的参数,自行调整一下。在本次训练中调整起到的效果有限(而且还花了一堆训练时间)。
BCE 改进:带权
之前提到过,何切问题可能有多个答案,所以可以用 top-k 策略优化损失函数。问 LLM 可得代码:
class WeightedBCELoss(nn.Module):
def __init__(self, top_k=3, primary_weight=2.0, secondary_weight=1.0):
super().__init__()
self.top_k = top_k
self.primary_weight = primary_weight
self.secondary_weight = secondary_weight
self.bce = nn.BCELoss(reduction='none')
def forward(self, input, target):
# 计算基础BCE损失
loss = self.bce(input, target)
# 对top-k标注给予更高权重
batch_size = target.size(0)
for i in range(batch_size):
# 找出标注中概率最高的top-k张牌
_, topk_indices = torch.topk(target[i], self.top_k)
# 对这些位置应用更高权重
loss[i][topk_indices] *= self.primary_weight
# 对其他位置应用标准权重
other_mask = torch.ones_like(target[i], dtype=torch.bool)
other_mask[topk_indices] = False
loss[i][other_mask] *= self.secondary_weight
return loss.mean()
之前优化器和损失函数初始化的部分稍微修改一下:
criterion_kiru = WeightedBCELoss(top_k=3, primary_weight=2.0, secondary_weight=0.5)
criterion_shanten = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
数据强化
数据强化,是一种对已有的数据进行一些变换后加入数据集,从而得到更泛化也更丰富的数据集的数据处理策略。
强化的意义不仅是扩充数据集,还能在一定程度上提示模型去注意一些特征——例如,要如何让模型知道,花色本身意义并不大,故可以对现有的数据进行花色变换,例如把万字牌和饼子牌数量整体对换。代码:
def augment(item):
def part_swap(arr):
a, b = random.sample([0, 1, 2], 2)
arr[a*9:(a+1)*9], arr[b*9:(b+1)*9] = arr[b*9:(b+1)*9], arr[a*9:(a+1)*9]
return arr
for i in item:
if i == 'shanten': # 向听数不是向量不用改变
continue
part_swap(item[i])
如此产生新数据,结合原版数据,模型就更有可能学习到“花色基本无关”这一点。(说是基本无关,是因为绿一色役满要求必须是索子牌,不过考虑到绿一色十分少见,这里就忽略)。
实战
包装一下,再把那个计算有效牌(即计算进张)的程序也引进来,运行时界面大概长这样:
事实上这个程序打打低端局已经差不多了,如果是天凤段位估计能打到五六段。之前不是说一堆人打不过天凤牌理嘛……
但是扛不住恶调
攻守兼备的模型也扛不住
我谢谢你啊发牌员
参数文件
不给😋自己练
(想给但是不太敢)
(其实本文已经把我自己做的一个真正意义上的模型的一代的关键原理公开了)
(有一个也是基于纯牌效的麻将辅助机器人在 github 上是全开放的,那个更强)
(珍惜自己账号)