CF2181L LLM Training

题目描述

给定一个文本数据集。你的任务是训练一个大型语言模型(LLM),并找到最小可能损失。不开玩笑。 一个文本数据集由若干文本 $t_1, t_2, \ldots, t_n$ 组成。每个文本 $t_i$ 是一个由多个 token 组成的序列。我们定义 token 集合 $T$ 为至少在一个 $t_i$ 中出现过的所有 token 的集合。此外,对于每个文本 $t_i$,还有一个位置集合 $L_i \subseteq \{1, 2, \ldots, |t_i|\}$。当 $j \in L_i$ 时,token $t_i[j]$ 由 LLM 生成。当 $j \notin L_i$ 时,该 token 由用户输入。 我们定义上下文长度为 $k$ 的 LLM 为概率模型 $P_k$,它定义了下一个 token 的概率分布,并依赖于上下文 $w$——一个长度在 $0$ 到 $k$(含)之间的、元素来自 $T$ 的序列。因此概率模型 $P_k$ 本质上是一个巨大的概率表,对每个上下文 $w \in T^{*}$,$0 \leq |w| \leq k$ 和每个 token $\text{next} \in T$,都给定 $P_k(\text{next} | w)$。这些概率应满足 $0 \leq P_k(\text{next} | w) \leq 1$,且 $\sum\limits_{\text{next} \in T} P_k(\text{next} | w) = 1$。 LLM 损失函数如下定义,对于 $P_k$: $$ \mathcal{L}_k(P_k) = \sum_{i=1}^{n} \sum_{j\in L_i} -\log_2 P_k\!\left( \underbrace{t_i[j]}_{\text{下一个 token}} \ \middle|\ \underbrace{t_i[\max(1, j-k)\ldots j-1]}_{\text{上下文}} \right) $$ 这里 $t_i[l\,..\,r] = t_i[l] t_i[l+1] \ldots t_i[r]$ 是从第 $l$ 到第 $r$ 个 token 的子串,$t_i[1..0]$ 表示空串。所以对于每个文本、每个由 LLM 生成的位置,我们将根据前 $k$ 个 token(或整个前缀,如果长度不足 $k$)的子串,加上当前 token 的概率的负对数(以 2 为底)到损失中。如果概率为 0,则负对数视为 $+\infty$。该损失函数称为(以 2 为底)的交叉熵损失(Cross Entropy Loss),只针对 LLM 生成的位置。$\mathcal{L}_k(P_k)$ 越小,LLM $P_k$ 越好。 对于每个 $0 \leq k < \max\limits_{i=1..n} |t_i|$,请计算某个具有上下文长度 $k$ 的 LLM 所能达到的最小可能的损失 $\mathcal{L}_k(P_k)$。可以证明,这个最小值是可达的且不是无穷大。

输入格式

第一行包含一个整数 $n$($1 \leq n \leq 10^5$),表示数据集中文本的数量。接下来是每个文本的描述。 第 $i$ 个文本的描述第一行为一个整数 $m_i$($1 \leq m_i \leq 3 \cdot 10^5$),表示 $t_i$ 的长度($m_i = |t_i|$)。 下一行包含 $m_i$ 个字符串 $t_{i}[1]$、$t_{i}[2]$、$\ldots$、$t_{i}[m_i]$($1 \leq |t_{i}[j]| \leq 5$),为该文本的每个 token。每个 token 由 ASCII 码 33 至 126(可打印字符)的字符构成。 下一行包含一个长度为 $m_i$ 的字符串 $\ell_i$,由字母 U 和 L 组成,编码了 $L_i$。所有位置中,字母 L 处对应由 LLM 生成,U 则由用户输入。所以 $L_i = \{j\,|\,\ell_i[j] = \texttt{L}\}$。保证每个文本最后一个 token 都是 LLM 生成的,即 $\ell_i[m_i] = \texttt{L}$。 保证所有 $m_i$ 之和不超过 $3 \cdot 10^5$。

输出格式

输出 $M = \max\limits_{i=1..n} m_i$ 个实数:对于每个 $k = 0, 1, \ldots, M-1$,输出最小可能损失 $\mathcal{L}_k(P_k)$,即所有可能的上下文长度为 $k$ 的 LLM 的最小损失。 如果你的答案的绝对误差或相对误差不超过 $10^{-6}$,即对于你的答案 $p$ 和标准答案 $q$,有 $\frac{|p - q|}{\max\{1, |q|\}} \leq 10^{-6}$,即视为正确。

说明/提示

由 ChatGPT 5 翻译