Meta提出了一种透过多token预测(Multi-token Prediction)来训练更好、更快的大型语言模型的方法。这篇论文的重点如下:
训练语言模型同时预测多个未来的token,可以提高样本效率(sample efficiency)。
在推论阶段,使用多token预测可以达到最高3倍的加速。
论文的主要贡献包括:
- 提出了一种简单的多token预测架构,在训练时间和内存使用上没有额外开销。
实验证明,这种训练范式在大规模模型(最高达130亿参数)上是有效的,平均可以解决大约15%以上的编程问题 - 多token预测使得自我推测解碼(self-speculative decoding)成为可能,在各种批次大小下将模型的推论速度提高了最多3倍。
https://arxiv.org/pdf/2404.19737
动机与目的
传统的语言模型通常使用下一个token预测(next-token prediction)的方式进行训练,即根据前面的token序列,预测下一个最可能出现的token。然而,这种训练方式可能导致模型过度关注局部的模式,忽略了长程的依赖关系。为了解决这个问题,本文提出了多token预测(multi-token prediction)的训练方法,同时预测未来的多个token,以提升语言模型的训练效率和性能。
方法原理
模型架构
语言模型使用一个共享的模型主体(shared model trunk),并在其上添加n个独立的输出头(output head),分别预测未来的n个token。
在训练时,模型在每个位置同时预测未来的n个token,使用n个独立的loss项。
为了减少GPU内存用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。
推论时,可以只用第一个输出头(也就是下一个token的预测),其余输出头可选择性地用于加速推论(称为self-speculative decoding)。
训练目标
在训练时,模型在每个位置同时预测未来的 n n n个token,使用 n n n个独立的cross-entropy loss项。假设输入的token序列为 x 1 , x 2 , . . . , x t , x_1, x_2, ..., x_t, x1,x2,...,xt,模型的训练目标可以表示为:
L n = − Σ t l o g P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t log P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) Ln=−ΣtlogP(xt+1,...,xt+n∣x1,...,xt)
其中, P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) P(xt+1,...,xt+n∣x1,...,xt)表示在给定前 t t t个token的条件下,未来 n n n个token的联合概率分布。将这个联合概率分解为 n n n个条件概率的乘积,可以得到:
L n = − Σ t [ l o g P ( x t + 1 ∣ x 1 , . . . , x t ) + l o g P ( x t + 2 ∣ x 1 , . . . , x t ) + . . . + l o g P ( x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t [log P(x_{t+1} | x_1, ..., x_t) + log P(x_{t+2} | x_1, ..., x_t) + ... + log P(x_{t+n} | x_1, ..., x_t) Ln=−Σt[logP(xt+1∣x1,...,xt)+logP(xt+2∣x1,...,xt)+...+logP(xt+n∣x1,...,xt))
每个条件概率 P ( x t + i ∣ x 1 , . . . , x t ) P(x_{t+i} | x_1, ..., x_t) P(xt+i∣x1,...,xt)由一个独立的输出头计算得到。
训练技巧
为了减少GPU内存的使用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。这种技巧使得多token预测模型的训练几乎不增加额外的计算和存储开销。
推论过程
在推论阶段,可以只使用第一个输出头(即下一个token的预测),其余输出头可选择性地用于加速推论。这种加速技术称为self-speculative decoding,通过并行计算多个输出头的预测结果,可以提高推论的效率。
实验结果
作者在多个编码和自然语言任务上评估了多token预测模型的性能,并与传统的下一个token预测模型进行了比较。
编码任务
在HumanEval和MBPP两个编码数据集上,多token预测模型显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。4个token的预测在综合表现上最佳,在HumanEval上pass@100提升了4.1%,在MBPP上pass@1提升了3.8%。此外,训练多个epoch时,多token预测的优势仍然存在。
自然语言任务
在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。在8个摘要数据集上,2个token的预测平均将ROUGE-L提升了0.51,4个token的预测平均提升了0.46。在GSM8K自然语言数学数据集上,2个token的预测模型显著优于基准模型。
字符级训练
为了验证多token预测有助于学习更长程的依赖关系,作者进行了字符级(byte-level)的训练实验。结果表明,8个字符的多token预测模型在HumanEval上pass@1的表现比下一个字符预测模型高出20%,在MBPP上高出67%。这说明多token预测能够捕捉更长距离的模式和依赖关系。
模型微调
使用预训练的多token预测模型进行微调,也能在下游任务上取得优于基准模型的成果。在CodeContests数据集上,4个token预训练的模型在pass@k上全面超过了下一个token预训练的模型。
在编码(coding)任务上,多token预测模型在HumanEval和MBPP数据集上的表现显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。
在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。
多token预测有助于模型学习更长程的依赖关系。在字符级(byte-level)的训练中,8个字符的多token预测大幅优于下一个字符预测。
实验显示,4个token的预测在综合表现上最佳。此外,训练多个epoch时,多token预测的优势仍然存在。
使用训练好的多token预测模型进行微调(如在CodeContests数据集上),也能取得优于基准模型的成果。
额外的输出头可用于self-speculative decoding,在推论阶段提供最高3倍的加速。
结论与讨论
本文提出了一种简单而有效的语言模型训练方法——多token预测,通过同时预测未来的多个token,促进模型学习更长程的依赖关系。实验结果表明,这种方法在编码和自然语言任务上带来了显著的性能提升,尤其对大模型和较长文本的生成任务效果更佳。多token预测几乎不增加训练成本,却能提高训练和推论效率,值得进一步探索。
作者认为,这项工作为寻找更有效的语言模型训练方法开辟了新的方向。未来的研究可以探索以下几个方面:
- 在更大规模的数据集和模型上验证多token预测的有效性。
- 研究最优的token预测数量n,以及如何自适应地选择n。
- 设计更高效的多token预测架构,如使用单一的输出头来预测多个token。
- 将多token预测与其他辅助训练目标结合,如掩码语言建模(masked language modeling)。
多token预测是一种前景广阔的语言模型训练方法,有望帮助构建更强大、更连贯的语言模型,推动自然语言处理领域的发展。
以下是我对这项工作的一些想法:
Meta最近提出了一种简单而有效的语言模型训练方法—多token预测(Multi-Token Prediction,简称MTP)。传统的语言模型通常每次只预测一个token,而MTP则在每个时间步预测多个token,从而提高训练效率。
核心思想:
在每个时间步,模型预测接下来的n个token,而不是1个
将这n个token打包成一个单独的预测目标,用一个特殊的分隔符隔开
模型的输出是长度为n的token序列,用交叉熵损失函数优化
优点:
预测多个token,捕捉更长距离的依赖,学到更强的上下文表征
并行化程度高,加快训练速度,节省显存
实现简单,几乎不增加模型参数量
在下游任务上finetune,相比传统方法能取得更好的效果
实验结果表明,相比标准的next token prediction,MTP能以更低的训练成本取得更好的性能。比如在相同的计算预算下,MTP的WikiText-103困惑度比传统方法低15%以上。
总之,多token预测是一种简洁而强大的语言模型训练范式。通过预测多个token,它能学到更丰富的上下文信息。同时并行化程度高,训练高效。Meta的这项工作为语言模型的训练提供了新的思路。
多token预测利用了语言的长程依赖关系,通过同时预测多个未来的token,促使模型学习更全面、更连贯的表示。这种方法与人类语言学习的过程更为相似,因为我们在理解和生成语言时,也是基于对未来一段文本的预期,而不仅仅依赖于前一个词。
该方法在编程任务上取得了显著的性能提升,这可能是因为编程语言具有更强的结构性和逻辑性,多token预测更容易捕捉到其中的模式和依赖关系。在自然语言任务上的改进相对较小,可能是因为自然语言的不确定性和灵活性更高,单纯增加预测的token数量效果有限,需要更细致的建模方法。
多token预测在推论阶段带来的加速效果非常可观,这对于实际应用中的延迟敏感场景(如实时对话、同步翻译等)具有重要价值。不过,这种加速方法对模型性能的影响还需要进一步评估,确保生成质量不会显著下降。
论文中的实验主要集中在编程和自然语言文本上,未来可以考虑将多token预测应用于其他类型的序列数据,如时间序列、生物序列等,探索它在更广泛领域的有效性。
多token预测作为一种辅助的训练目标,与其他方法(如对比学习、知识蒸馏等)结合使用,可能会产生更好的协同效果。探索多种训练策略的组合,有望进一步提升语言模型的性能和泛化能力。
我认为这项工作为改进大型语言模型的训练和推理效率提供了一个简单而有效的思路,具有广阔的应用前景。未来可以在更大规模的数据集和模型上验证这种方法的有效性,并探索与其他技术结合的可能性,推动语言模型的进一步发展。