Meta更低的训练成本取得更好的性能: 多token预测(Multi-Token Prediction)

Meta提出了一种透过多token预测(Multi-token Prediction)来训练更好、更快的大型语言模型的方法。这篇论文的重点如下:

训练语言模型同时预测多个未来的token,可以提高样本效率(sample efficiency)。
在推论阶段,使用多token预测可以达到最高3倍的加速。

在这里插入图片描述

论文的主要贡献包括:

  • 提出了一种简单的多token预测架构,在训练时间和内存使用上没有额外开销。
    实验证明,这种训练范式在大规模模型(最高达130亿参数)上是有效的,平均可以解决大约15%以上的编程问题
  • 多token预测使得自我推测解碼(self-speculative decoding)成为可能,在各种批次大小下将模型的推论速度提高了最多3倍。

https://arxiv.org/pdf/2404.19737

动机与目的

传统的语言模型通常使用下一个token预测(next-token prediction)的方式进行训练,即根据前面的token序列,预测下一个最可能出现的token。然而,这种训练方式可能导致模型过度关注局部的模式,忽略了长程的依赖关系。为了解决这个问题,本文提出了多token预测(multi-token prediction)的训练方法,同时预测未来的多个token,以提升语言模型的训练效率和性能。

在这里插入图片描述

方法原理

模型架构

语言模型使用一个共享的模型主体(shared model trunk),并在其上添加n个独立的输出头(output head),分别预测未来的n个token。
在训练时,模型在每个位置同时预测未来的n个token,使用n个独立的loss项。
为了减少GPU内存用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。
推论时,可以只用第一个输出头(也就是下一个token的预测),其余输出头可选择性地用于加速推论(称为self-speculative decoding)。

在这里插入图片描述

训练目标

在训练时,模型在每个位置同时预测未来的 n n n个token,使用 n n n个独立的cross-entropy loss项。假设输入的token序列为 x 1 , x 2 , . . . , x t , x_1, x_2, ..., x_t, x1,x2,...,xt,模型的训练目标可以表示为:

L n = − Σ t l o g P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t log P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) Ln=ΣtlogP(xt+1,...,xt+nx1,...,xt)

其中, P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) P(xt+1,...,xt+nx1,...,xt)表示在给定前 t t t个token的条件下,未来 n n n个token的联合概率分布。将这个联合概率分解为 n n n个条件概率的乘积,可以得到:

L n = − Σ t [ l o g P ( x t + 1 ∣ x 1 , . . . , x t ) + l o g P ( x t + 2 ∣ x 1 , . . . , x t ) + . . . + l o g P ( x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t [log P(x_{t+1} | x_1, ..., x_t) + log P(x_{t+2} | x_1, ..., x_t) + ... + log P(x_{t+n} | x_1, ..., x_t) Ln=Σt[logP(xt+1x1,...,xt)+logP(xt+2x1,...,xt)+...+logP(xt+nx1,...,xt))

每个条件概率 P ( x t + i ∣ x 1 , . . . , x t ) P(x_{t+i} | x_1, ..., x_t) P(xt+ix1,...,xt)由一个独立的输出头计算得到。

训练技巧

为了减少GPU内存的使用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。这种技巧使得多token预测模型的训练几乎不增加额外的计算和存储开销。

在这里插入图片描述

推论过程

在推论阶段,可以只使用第一个输出头(即下一个token的预测),其余输出头可选择性地用于加速推论。这种加速技术称为self-speculative decoding,通过并行计算多个输出头的预测结果,可以提高推论的效率。

实验结果

作者在多个编码和自然语言任务上评估了多token预测模型的性能,并与传统的下一个token预测模型进行了比较。

在这里插入图片描述

编码任务

在HumanEval和MBPP两个编码数据集上,多token预测模型显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。4个token的预测在综合表现上最佳,在HumanEval上pass@100提升了4.1%,在MBPP上pass@1提升了3.8%。此外,训练多个epoch时,多token预测的优势仍然存在。

自然语言任务

在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。在8个摘要数据集上,2个token的预测平均将ROUGE-L提升了0.51,4个token的预测平均提升了0.46。在GSM8K自然语言数学数据集上,2个token的预测模型显著优于基准模型。

字符级训练

在这里插入图片描述

为了验证多token预测有助于学习更长程的依赖关系,作者进行了字符级(byte-level)的训练实验。结果表明,8个字符的多token预测模型在HumanEval上pass@1的表现比下一个字符预测模型高出20%,在MBPP上高出67%。这说明多token预测能够捕捉更长距离的模式和依赖关系。

模型微调

使用预训练的多token预测模型进行微调,也能在下游任务上取得优于基准模型的成果。在CodeContests数据集上,4个token预训练的模型在pass@k上全面超过了下一个token预训练的模型。

在这里插入图片描述

在编码(coding)任务上,多token预测模型在HumanEval和MBPP数据集上的表现显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。
在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。
多token预测有助于模型学习更长程的依赖关系。在字符级(byte-level)的训练中,8个字符的多token预测大幅优于下一个字符预测。
实验显示,4个token的预测在综合表现上最佳。此外,训练多个epoch时,多token预测的优势仍然存在。
使用训练好的多token预测模型进行微调(如在CodeContests数据集上),也能取得优于基准模型的成果。
额外的输出头可用于self-speculative decoding,在推论阶段提供最高3倍的加速。

在这里插入图片描述

结论与讨论

本文提出了一种简单而有效的语言模型训练方法——多token预测,通过同时预测未来的多个token,促进模型学习更长程的依赖关系。实验结果表明,这种方法在编码和自然语言任务上带来了显著的性能提升,尤其对大模型和较长文本的生成任务效果更佳。多token预测几乎不增加训练成本,却能提高训练和推论效率,值得进一步探索。

在这里插入图片描述

作者认为,这项工作为寻找更有效的语言模型训练方法开辟了新的方向。未来的研究可以探索以下几个方面:

  1. 在更大规模的数据集和模型上验证多token预测的有效性。
  2. 研究最优的token预测数量n,以及如何自适应地选择n。
  3. 设计更高效的多token预测架构,如使用单一的输出头来预测多个token。
  4. 将多token预测与其他辅助训练目标结合,如掩码语言建模(masked language modeling)。在这里插入图片描述
    在这里插入图片描述

多token预测是一种前景广阔的语言模型训练方法,有望帮助构建更强大、更连贯的语言模型,推动自然语言处理领域的发展。

以下是我对这项工作的一些想法:

Meta最近提出了一种简单而有效的语言模型训练方法—多token预测(Multi-Token Prediction,简称MTP)。传统的语言模型通常每次只预测一个token,而MTP则在每个时间步预测多个token,从而提高训练效率。
核心思想:

在每个时间步,模型预测接下来的n个token,而不是1个
将这n个token打包成一个单独的预测目标,用一个特殊的分隔符隔开
模型的输出是长度为n的token序列,用交叉熵损失函数优化

优点:

预测多个token,捕捉更长距离的依赖,学到更强的上下文表征
并行化程度高,加快训练速度,节省显存
实现简单,几乎不增加模型参数量
在下游任务上finetune,相比传统方法能取得更好的效果

实验结果表明,相比标准的next token prediction,MTP能以更低的训练成本取得更好的性能。比如在相同的计算预算下,MTP的WikiText-103困惑度比传统方法低15%以上。
总之,多token预测是一种简洁而强大的语言模型训练范式。通过预测多个token,它能学到更丰富的上下文信息。同时并行化程度高,训练高效。Meta的这项工作为语言模型的训练提供了新的思路。

多token预测利用了语言的长程依赖关系,通过同时预测多个未来的token,促使模型学习更全面、更连贯的表示。这种方法与人类语言学习的过程更为相似,因为我们在理解和生成语言时,也是基于对未来一段文本的预期,而不仅仅依赖于前一个词。

该方法在编程任务上取得了显著的性能提升,这可能是因为编程语言具有更强的结构性和逻辑性,多token预测更容易捕捉到其中的模式和依赖关系。在自然语言任务上的改进相对较小,可能是因为自然语言的不确定性和灵活性更高,单纯增加预测的token数量效果有限,需要更细致的建模方法。

多token预测在推论阶段带来的加速效果非常可观,这对于实际应用中的延迟敏感场景(如实时对话、同步翻译等)具有重要价值。不过,这种加速方法对模型性能的影响还需要进一步评估,确保生成质量不会显著下降。

论文中的实验主要集中在编程和自然语言文本上,未来可以考虑将多token预测应用于其他类型的序列数据,如时间序列、生物序列等,探索它在更广泛领域的有效性。

多token预测作为一种辅助的训练目标,与其他方法(如对比学习、知识蒸馏等)结合使用,可能会产生更好的协同效果。探索多种训练策略的组合,有望进一步提升语言模型的性能和泛化能力。

我认为这项工作为改进大型语言模型的训练和推理效率提供了一个简单而有效的思路,具有广阔的应用前景。未来可以在更大规模的数据集和模型上验证这种方法的有效性,并探索与其他技术结合的可能性,推动语言模型的进一步发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/8713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Django中如何使用WebSocket实时更新数据?

在Django中使用WebSocket实时更新数据,可以通过使用第三方库Django Channels实现。Django Channels是基于WebSocket的实时通信框架,它使得Django应用可以处理实时的、异步的任务。 下面是使用Django Channels实时更新数据的一般步骤: 安装D…

ES集群数据备份与迁移

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、文章涉及概念讲解二、操作步骤1.创建 snapshot repository操作主机hadoop1分别操作从机hadoop2和hadoop3 2. 查看仓库信息3. 备份索引,生成快照…

【S32K UDS BootLoader】-1.1-Unified bootloader Demo和ECUBus工具的使用

<--返回「Autosar_MCAL高阶配置」专栏主页--> 目录 1 下载S32K1/S32K3/S12Z Unified bootloader Demo 1.1 在S32DS中编译S32K312_CAN_bootloader_RTD2d0工程并烧录 2 ECUBus工具使用 2.1 PCAN环境搭建 1.1.1 安装PCAN驱动 1.1.2 安装PCAN-View 2.2 下载并安装ECU…

蓝桥杯 BASIC-26 基础练习 报时助手

蓝桥杯 BASIC-26 基础练习 报时助手 问题描述 给定当前的时间&#xff0c;请用英文的读法将它读出来。 时间用时h和分m表示&#xff0c;在英文的读法中&#xff0c;读一个时间的方法是&#xff1a; 如果m为0&#xff0c;则将时读出来&#xff0c;然后加上“o’clock”&#xff…

嵌入式C语言的变量和函数存储类型

目录 概述 1 嵌入式C的数据类型 2 嵌入式C语言存储类型 2.1 auto存储类型 2.2 extern存储类型 2.3 register存储类型 2.4 static存储类型 概述 本文主要介绍嵌入式C语言中的数据变量的类型&#xff0c;包括其数据长度&#xff0c;在内存中的存储方式。还介绍了数据的存储…

C语言 | Leetcode C语言题解之第77题组合

题目&#xff1a; 题解&#xff1a; int** combine(int n, int k, int* returnSize, int** returnColumnSizes) {int* temp malloc(sizeof(int) * (k 1));int tempSize 0;int** ans malloc(sizeof(int*) * 200001);int ansSize 0;// 初始化// 将 temp 中 [0, k - 1] 每个…

Vue项目中使用echarts教程

Vue项目中使用echarts教程 步骤npm 安装ECharts引入 ECharts老版本引入方式 &#xff08;v4版本&#xff09;新版本引入方式 &#xff08;v5版本&#xff09; ECharts初体验ECharts组件化&#xff08;进阶写法&#xff09; 步骤 npm 安装ECharts npm install echarts --save引…

回答篇:测试开发高频面试题目

引用之前文章&#xff1a;《测试开发高频面试题目》 https://blog.csdn.net/qq_41214208/article/details/138193469?spm1001.2014.3001.5502 本篇文章是回答篇&#xff08;持续更新中&#xff09; 1. 什么是测试开发以及其在软件开发流程中的作用。 a. 测试开发是指测试人员或…

关于Anaconda常用的命令

常用命令 查看当前环境下的环境&#xff1a;conda env list查看当前conda的版本&#xff1b;conda --version conda create -n your_env_name pythonX.X&#xff08;2.7、3.6等)命令创建python版本为X.X。名字为your_env_name的虚拟环境。your_env_name文件可以在Anaconda安装…

收银系统源码--什么是千呼智慧新零售系统?

千呼智慧新零售系统是一套针对零售行业线上线下一体化收银系统。给门店提供线下称重收银、o2o线上商城、erp进销存、精细化会员管理、丰富营销插件等一体化解决方案。多端数据打通&#xff0c;实现线上线下一体化&#xff0c;提升门店工作效率&#xff0c;实现数字化升级&#…

前端项目加载离线的百度地图,利用工具进行切指定区域的地图影像,自定义图层getTilesUrl

百度地图在开发中我们经常使用&#xff0c;但是有些项目是需要在内网进行&#xff0c;这时候我们不得不考虑项目中一些功能需要请求外网静态资源&#xff0c;比如百度地图。只有把包下载到本地&#xff0c;才能让静态资源文件的正常的访问。 目录 获取百度地图开发秘钥 引入在…

设计模式——装饰者模式(Decorator)

装饰者模式&#xff08;Decorator Pattern&#xff09;是一种结构型设计模式&#xff0c;它允许你动态地给一个对象添加一些额外的职责&#xff0c;就增加功能来说&#xff0c;装饰者模式相比生成子类更为灵活。在装饰者模式中&#xff0c;一个装饰类会包装一个对象&#xff08…

Transformer优化加速--xformers

一、定义 1 作用 2 优化创新点 3. 使用demo 二、实现 作用 facebook 提出&#xff0c; xformers能够有效加速attention计算并降低显存。 参考&#xff1a; https://github.com/facebookresearch/xformers https://zhuanlan.zhihu.com/p/688745007 接口&#xff1a;https://f…

Java | Leetcode Java题解之第78题子集

题目&#xff1a; 题解&#xff1a; class Solution {List<Integer> t new ArrayList<Integer>();List<List<Integer>> ans new ArrayList<List<Integer>>();public List<List<Integer>> subsets(int[] nums) {dfs(0, nums…

C++容器——map和pair对组

pair&#xff08;对组&#xff09; 是一种模板类&#xff0c;允许将两个不同类型的值组合在一起。它由两个数据成员first和second组成&#xff0c;分别用来保存这两个值。 头文件 加头文件 #include<utility> 对于 C11 及以上标准&#xff0c;pair 类型可以在不包含头…

牛客网刷题 | BC81 KiKi求质数个数

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 描述 KiKi知道了什么是质…

【离散数学】集合上二元关系性质判定的实现(c语言实现)

实验要求 关系矩阵的初始化和打印 我们将关系矩阵存入一个二维数组中&#xff0c;因为集合元素个数不会超过5个所以就用一个5行5列二维数组来表示。 在我们得到了集合元素个数之后我们就可以对数组进行0,1随机赋值 //初始关系矩阵 void init_matrix(int array[][5], int n) {…

python使用f-string时如何保留原始的{}

如果想在 f-string 中使用 {} 符号&#xff0c;但又不想让它被解释成 f-string 的占位符&#xff0c;可以使用两个连续的 {} 来表示一个单独的 {} 符号&#xff0c;从而使其保留原始的形式。 例如&#xff1a; name "John" age 30 text f"{{Hello {name}, …

力扣:1005. K 次取反后最大化的数组和

1005. K 次取反后最大化的数组和 给你一个整数数组 nums 和一个整数 k &#xff0c;按以下方法修改该数组&#xff1a; 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以这种方式修改数组后&#xff0c;返回数组 可能…

多核DSP并行计算跨平台通信解决方案

并行计算的核心是计算节点以及节点间的通信与协调机制。OpenMP虽然给开发者提供了极易上手的增量式开发方式&#xff0c;但是OpenMP在与复杂架构的MCSDK结合后&#xff0c;工具与代码产生了大量不可调试的黑盒子&#xff0c;更是决定了它不能用于关键任务领域&#xff0c;如军工…