Meta更低的训练成本取得更好的性能: 多token预测(Multi-Token Prediction)

Meta提出了一种透过多token预测(Multi-token Prediction)来训练更好、更快的大型语言模型的方法。这篇论文的重点如下:

训练语言模型同时预测多个未来的token,可以提高样本效率(sample efficiency)。
在推论阶段,使用多token预测可以达到最高3倍的加速。

在这里插入图片描述

论文的主要贡献包括:

  • 提出了一种简单的多token预测架构,在训练时间和内存使用上没有额外开销。
    实验证明,这种训练范式在大规模模型(最高达130亿参数)上是有效的,平均可以解决大约15%以上的编程问题
  • 多token预测使得自我推测解碼(self-speculative decoding)成为可能,在各种批次大小下将模型的推论速度提高了最多3倍。

https://arxiv.org/pdf/2404.19737

动机与目的

传统的语言模型通常使用下一个token预测(next-token prediction)的方式进行训练,即根据前面的token序列,预测下一个最可能出现的token。然而,这种训练方式可能导致模型过度关注局部的模式,忽略了长程的依赖关系。为了解决这个问题,本文提出了多token预测(multi-token prediction)的训练方法,同时预测未来的多个token,以提升语言模型的训练效率和性能。

在这里插入图片描述

方法原理

模型架构

语言模型使用一个共享的模型主体(shared model trunk),并在其上添加n个独立的输出头(output head),分别预测未来的n个token。
在训练时,模型在每个位置同时预测未来的n个token,使用n个独立的loss项。
为了减少GPU内存用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。
推论时,可以只用第一个输出头(也就是下一个token的预测),其余输出头可选择性地用于加速推论(称为self-speculative decoding)。

在这里插入图片描述

训练目标

在训练时,模型在每个位置同时预测未来的 n n n个token,使用 n n n个独立的cross-entropy loss项。假设输入的token序列为 x 1 , x 2 , . . . , x t , x_1, x_2, ..., x_t, x1,x2,...,xt,模型的训练目标可以表示为:

L n = − Σ t l o g P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t log P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) Ln=ΣtlogP(xt+1,...,xt+nx1,...,xt)

其中, P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) P(xt+1,...,xt+nx1,...,xt)表示在给定前 t t t个token的条件下,未来 n n n个token的联合概率分布。将这个联合概率分解为 n n n个条件概率的乘积,可以得到:

L n = − Σ t [ l o g P ( x t + 1 ∣ x 1 , . . . , x t ) + l o g P ( x t + 2 ∣ x 1 , . . . , x t ) + . . . + l o g P ( x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t [log P(x_{t+1} | x_1, ..., x_t) + log P(x_{t+2} | x_1, ..., x_t) + ... + log P(x_{t+n} | x_1, ..., x_t) Ln=Σt[logP(xt+1x1,...,xt)+logP(xt+2x1,...,xt)+...+logP(xt+nx1,...,xt))

每个条件概率 P ( x t + i ∣ x 1 , . . . , x t ) P(x_{t+i} | x_1, ..., x_t) P(xt+ix1,...,xt)由一个独立的输出头计算得到。

训练技巧

为了减少GPU内存的使用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。这种技巧使得多token预测模型的训练几乎不增加额外的计算和存储开销。

在这里插入图片描述

推论过程

在推论阶段,可以只使用第一个输出头(即下一个token的预测),其余输出头可选择性地用于加速推论。这种加速技术称为self-speculative decoding,通过并行计算多个输出头的预测结果,可以提高推论的效率。

实验结果

作者在多个编码和自然语言任务上评估了多token预测模型的性能,并与传统的下一个token预测模型进行了比较。

在这里插入图片描述

编码任务

在HumanEval和MBPP两个编码数据集上,多token预测模型显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。4个token的预测在综合表现上最佳,在HumanEval上pass@100提升了4.1%,在MBPP上pass@1提升了3.8%。此外,训练多个epoch时,多token预测的优势仍然存在。

自然语言任务

在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。在8个摘要数据集上,2个token的预测平均将ROUGE-L提升了0.51,4个token的预测平均提升了0.46。在GSM8K自然语言数学数据集上,2个token的预测模型显著优于基准模型。

字符级训练

在这里插入图片描述

为了验证多token预测有助于学习更长程的依赖关系,作者进行了字符级(byte-level)的训练实验。结果表明,8个字符的多token预测模型在HumanEval上pass@1的表现比下一个字符预测模型高出20%,在MBPP上高出67%。这说明多token预测能够捕捉更长距离的模式和依赖关系。

模型微调

使用预训练的多token预测模型进行微调,也能在下游任务上取得优于基准模型的成果。在CodeContests数据集上,4个token预训练的模型在pass@k上全面超过了下一个token预训练的模型。

在这里插入图片描述

在编码(coding)任务上,多token预测模型在HumanEval和MBPP数据集上的表现显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。
在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。
多token预测有助于模型学习更长程的依赖关系。在字符级(byte-level)的训练中,8个字符的多token预测大幅优于下一个字符预测。
实验显示,4个token的预测在综合表现上最佳。此外,训练多个epoch时,多token预测的优势仍然存在。
使用训练好的多token预测模型进行微调(如在CodeContests数据集上),也能取得优于基准模型的成果。
额外的输出头可用于self-speculative decoding,在推论阶段提供最高3倍的加速。

在这里插入图片描述

结论与讨论

本文提出了一种简单而有效的语言模型训练方法——多token预测,通过同时预测未来的多个token,促进模型学习更长程的依赖关系。实验结果表明,这种方法在编码和自然语言任务上带来了显著的性能提升,尤其对大模型和较长文本的生成任务效果更佳。多token预测几乎不增加训练成本,却能提高训练和推论效率,值得进一步探索。

在这里插入图片描述

作者认为,这项工作为寻找更有效的语言模型训练方法开辟了新的方向。未来的研究可以探索以下几个方面:

  1. 在更大规模的数据集和模型上验证多token预测的有效性。
  2. 研究最优的token预测数量n,以及如何自适应地选择n。
  3. 设计更高效的多token预测架构,如使用单一的输出头来预测多个token。
  4. 将多token预测与其他辅助训练目标结合,如掩码语言建模(masked language modeling)。在这里插入图片描述
    在这里插入图片描述

多token预测是一种前景广阔的语言模型训练方法,有望帮助构建更强大、更连贯的语言模型,推动自然语言处理领域的发展。

以下是我对这项工作的一些想法:

Meta最近提出了一种简单而有效的语言模型训练方法—多token预测(Multi-Token Prediction,简称MTP)。传统的语言模型通常每次只预测一个token,而MTP则在每个时间步预测多个token,从而提高训练效率。
核心思想:

在每个时间步,模型预测接下来的n个token,而不是1个
将这n个token打包成一个单独的预测目标,用一个特殊的分隔符隔开
模型的输出是长度为n的token序列,用交叉熵损失函数优化

优点:

预测多个token,捕捉更长距离的依赖,学到更强的上下文表征
并行化程度高,加快训练速度,节省显存
实现简单,几乎不增加模型参数量
在下游任务上finetune,相比传统方法能取得更好的效果

实验结果表明,相比标准的next token prediction,MTP能以更低的训练成本取得更好的性能。比如在相同的计算预算下,MTP的WikiText-103困惑度比传统方法低15%以上。
总之,多token预测是一种简洁而强大的语言模型训练范式。通过预测多个token,它能学到更丰富的上下文信息。同时并行化程度高,训练高效。Meta的这项工作为语言模型的训练提供了新的思路。

多token预测利用了语言的长程依赖关系,通过同时预测多个未来的token,促使模型学习更全面、更连贯的表示。这种方法与人类语言学习的过程更为相似,因为我们在理解和生成语言时,也是基于对未来一段文本的预期,而不仅仅依赖于前一个词。

该方法在编程任务上取得了显著的性能提升,这可能是因为编程语言具有更强的结构性和逻辑性,多token预测更容易捕捉到其中的模式和依赖关系。在自然语言任务上的改进相对较小,可能是因为自然语言的不确定性和灵活性更高,单纯增加预测的token数量效果有限,需要更细致的建模方法。

多token预测在推论阶段带来的加速效果非常可观,这对于实际应用中的延迟敏感场景(如实时对话、同步翻译等)具有重要价值。不过,这种加速方法对模型性能的影响还需要进一步评估,确保生成质量不会显著下降。

论文中的实验主要集中在编程和自然语言文本上,未来可以考虑将多token预测应用于其他类型的序列数据,如时间序列、生物序列等,探索它在更广泛领域的有效性。

多token预测作为一种辅助的训练目标,与其他方法(如对比学习、知识蒸馏等)结合使用,可能会产生更好的协同效果。探索多种训练策略的组合,有望进一步提升语言模型的性能和泛化能力。

我认为这项工作为改进大型语言模型的训练和推理效率提供了一个简单而有效的思路,具有广阔的应用前景。未来可以在更大规模的数据集和模型上验证这种方法的有效性,并探索与其他技术结合的可能性,推动语言模型的进一步发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/8713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES集群数据备份与迁移

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、文章涉及概念讲解二、操作步骤1.创建 snapshot repository操作主机hadoop1分别操作从机hadoop2和hadoop3 2. 查看仓库信息3. 备份索引,生成快照…

【S32K UDS BootLoader】-1.1-Unified bootloader Demo和ECUBus工具的使用

<--返回「Autosar_MCAL高阶配置」专栏主页--> 目录 1 下载S32K1/S32K3/S12Z Unified bootloader Demo 1.1 在S32DS中编译S32K312_CAN_bootloader_RTD2d0工程并烧录 2 ECUBus工具使用 2.1 PCAN环境搭建 1.1.1 安装PCAN驱动 1.1.2 安装PCAN-View 2.2 下载并安装ECU…

C语言 | Leetcode C语言题解之第77题组合

题目&#xff1a; 题解&#xff1a; int** combine(int n, int k, int* returnSize, int** returnColumnSizes) {int* temp malloc(sizeof(int) * (k 1));int tempSize 0;int** ans malloc(sizeof(int*) * 200001);int ansSize 0;// 初始化// 将 temp 中 [0, k - 1] 每个…

回答篇:测试开发高频面试题目

引用之前文章&#xff1a;《测试开发高频面试题目》 https://blog.csdn.net/qq_41214208/article/details/138193469?spm1001.2014.3001.5502 本篇文章是回答篇&#xff08;持续更新中&#xff09; 1. 什么是测试开发以及其在软件开发流程中的作用。 a. 测试开发是指测试人员或…

关于Anaconda常用的命令

常用命令 查看当前环境下的环境&#xff1a;conda env list查看当前conda的版本&#xff1b;conda --version conda create -n your_env_name pythonX.X&#xff08;2.7、3.6等)命令创建python版本为X.X。名字为your_env_name的虚拟环境。your_env_name文件可以在Anaconda安装…

收银系统源码--什么是千呼智慧新零售系统?

千呼智慧新零售系统是一套针对零售行业线上线下一体化收银系统。给门店提供线下称重收银、o2o线上商城、erp进销存、精细化会员管理、丰富营销插件等一体化解决方案。多端数据打通&#xff0c;实现线上线下一体化&#xff0c;提升门店工作效率&#xff0c;实现数字化升级&#…

前端项目加载离线的百度地图,利用工具进行切指定区域的地图影像,自定义图层getTilesUrl

百度地图在开发中我们经常使用&#xff0c;但是有些项目是需要在内网进行&#xff0c;这时候我们不得不考虑项目中一些功能需要请求外网静态资源&#xff0c;比如百度地图。只有把包下载到本地&#xff0c;才能让静态资源文件的正常的访问。 目录 获取百度地图开发秘钥 引入在…

Java | Leetcode Java题解之第78题子集

题目&#xff1a; 题解&#xff1a; class Solution {List<Integer> t new ArrayList<Integer>();List<List<Integer>> ans new ArrayList<List<Integer>>();public List<List<Integer>> subsets(int[] nums) {dfs(0, nums…

牛客网刷题 | BC81 KiKi求质数个数

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 描述 KiKi知道了什么是质…

【离散数学】集合上二元关系性质判定的实现(c语言实现)

实验要求 关系矩阵的初始化和打印 我们将关系矩阵存入一个二维数组中&#xff0c;因为集合元素个数不会超过5个所以就用一个5行5列二维数组来表示。 在我们得到了集合元素个数之后我们就可以对数组进行0,1随机赋值 //初始关系矩阵 void init_matrix(int array[][5], int n) {…

多核DSP并行计算跨平台通信解决方案

并行计算的核心是计算节点以及节点间的通信与协调机制。OpenMP虽然给开发者提供了极易上手的增量式开发方式&#xff0c;但是OpenMP在与复杂架构的MCSDK结合后&#xff0c;工具与代码产生了大量不可调试的黑盒子&#xff0c;更是决定了它不能用于关键任务领域&#xff0c;如军工…

算法学习Day2——单调栈习题

第一题&#xff0c;合并球 题解&#xff1a;一开始写了一次暴力双循环&#xff0c;直接O(n^2)严重超时&#xff0c;后面于是又想到了O(n)时间复杂度的链表&#xff0c;但是还是卡在 最后一个数据会TLE&#xff0c;我也是高兴的拍起来安塞腰鼓和华氏护肤水&#xff0c;后面学长给…

基于模糊控制的AMT自动变速汽车换档智能控制系统simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于模糊控制的AMT自动变速汽车换档智能控制系统simulink建模与仿真。 2.系统仿真结果 输入的V&#xff0c;Ac&#xff0c;a 输出的档位&#xff1a; 3.核心程序与模型 版…

【C语言】static关键字用法

目录 一、static修饰局部变量 二、static修饰全局变量 三、static修饰函数 一、static修饰局部变量 首先我们来看两段代码: 代码1&#xff08;不加static&#xff09; #include <stdio.h> void test() {int i 0;i;printf("%d ", i); } int main() {int i…

VMvare如何更改虚拟机内共享文件夹的挂载点

更改虚拟机内共享文件夹的路径 进入目录 /etc/init.d ,并找到vmware-tools文件 里面有配置项 vmhgfs_mnt"/mnt/hgfs" 将引号内的内容更改为你需要挂载的路径,重启即可 注意挂载的路径不能是 “/”&#xff0c;必须根目录下的某个文件夹&#xff0c;或者其子文件夹 …

使用Docker安装Yapi接口管理工具

简介&#xff1a; YAPI 是由去哪儿网移动架构组开发的一款可视化接口管理工具。它具有可视化管理、高效易用、功能强大等特点。它提供了便捷的接口创建、发布和维护方式&#xff0c;开发人员可以通过简单的操作实现接口管理。 YAPI 还支持类似 postman 的接口调试&#xff0c;对…

GPU通用计算介绍

谈到 GPU &#xff08;Graphics Processing Unit&#xff0c;图形显示卡&#xff09;大多数人想到的是游戏、图形渲染等这些词汇&#xff0c;图形处理确实是 GPU 的一大应用场景。然而人们也早已关注到它在通用计算上的巨大潜力&#xff0c;并提出了 GPGPU (General-purpose co…

Android进阶之路 - 静态会员进度条

年后这个新版本加入了VIP模块&#xff0c;有幸正好由我来负责&#xff0c;可以再积累一下这方面的知识。 那段时间看了一本书&#xff0c;书中说到初级码农的特性之一就是完全集中于某些功能&#xff0c;忽略了了很多成长机会&#xff0c;所以重复性劳作带来的成长值有限&#…

ETL工具中JSON格式的转换方式

JSON的用处 JSON&#xff08;JavaScript Object Notation&#xff09;是一种轻量级的数据交换格式&#xff0c;其设计初衷是为了提升网络应用中数据的传输效率及简化数据结构的解析过程。自其诞生以来&#xff0c;JSON 已成为Web开发乃至众多软件开发领域中不可或缺的一部分&a…

神经网络案例实战

&#x1f50e;我们通过一个案例详细使用PyTorch实战 &#xff0c;案例背景&#xff1a;你创办了一家手机公司&#xff0c;不知道如何估算手机产品的价格。为了解决这个问题&#xff0c;收集了多家公司的手机销售数据&#xff1a;这些数据维度可以包括RAM、存储容量、屏幕尺寸、…