LLaMA详细解读

LLaMA 是目前为止,效果最好的开源 LLM 之一。精读 LLaMA 的论文及代码,可以很好的了解 LLM 的内部原理。本文对 LLaMA 论文进行了介绍,同时附上了关键部分的代码,并对代码做了注释。

摘要

LLaMA是一个系列模型,模型参数量从7B到65B。在大部分的任务上,LLaMA-13B强于GPT-3(175B)。LLaMA-65B的性能,可以和最好的LM相媲美,如Chinchilla-70B 和 PaLM-540B。

一、引言

一般而言,模型越大,效果越好。然而有文献指出[1],当给定计算量的预算之后,最好的performance,并不是最大的模型,而是在一个小模型上用更多的数据进行训练。针对给定的计算量预算,scaling laws可以计算如何选择数据量的大小和模型的大小。然而这忽略了inference的预算,而这一点在模型推理时非常关键。当给定一个模型performance目标之后,最好的模型不是训练最快的模型,而是推理最快的模型。尽管在这种情况下,训练一个更大的模型成本会更低。

文献[2]中推荐,训练一个 10B 的模型,需要 200B 的 tokens,而本文的实验发现,一个7B的模型,经过 1T tokens 训练之后,performance 仍然在增加。本文的目标在于,通过在超大规模的数据上训练,给出一系列可能最好 performance 的 LLM。

二、预训练数据

2.1 数据集

一共有1.4T的tokens,大部分的训练数据都只用了一次,除了Wikipedia 和 Books 使用了大概2个epochs。

Pre-training data

2.2 tokenizer

使用byte pair encoding (BPE) 算法,使用的是Sentence-Piece的实现。所有数字被拆分为单独的digit,所有未知的UTF-8 字符,回退到字节来进行分解。因此,LLaMA 可以通过byte 的方式,构造出很多不在 vocab 中的字符,从而也具有较好的多语言能力。

三、网络结构改进

使用了基于transformer的架构,并做了如下3点改进:

3.1 Pre-normalization

为了提高训练的稳定性,对每个transformer层的输入进行归一化,而不是输出进行归一化。

同时,使用 RMS Norm 归一化函数。RMS Norm 的全称为 Root Mean Square layer normalization。与 layer Norm 相比,RMS Norm的主要区别在于去掉了减去均值的部分,计算公式为:

RMS Norm 的作者认为这种模式在简化了Layer Norm 的计算,可以在减少约 7%∼64% 的计算时间[3]。

class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypevariance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return (self.weight * hidden_states).to(input_dtype)

3.2 SwiGLU

使用SwiGLU替代了ReLU作为激活函数。和PaLM中不同,维度采用而不是 4𝑑 。

SwiGLU 在论文[4] 中提出,相比于其他的激活函数变体,可以取得 log-perplexity 的最优值(和 GEGLU 并列)。

GLU Variants Improve Transformer

SwiGLU 及几种类似变体的计算公式如下:

其中,。代码如下:

class LlamaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,hidden_act: str,):super().__init__()self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)# config 中 hidden_act = 'silu'# 'silu' 和 'swish' 对应的激活函数均为:SiLUActivation # https://github.com/huggingface/transformers/blob/717dadc6f36be9f50abc66adfd918f9b0e6e3502/src/transformers/activations.py#L229self.act_fn = ACT2FN[hidden_act]def forward(self, x):# 对应上述公式的 SwiGLUreturn self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

从代码可以看到 LlamaMLP 中一共有 3 个 Linear 层,原因就在于 SwiGLU 激活函数比类似 ReLU 的激活函数,需要多一个 Linear 层进行门控。

3.3 RoPE

RoPE 的核心思想是“通过绝对位置编码的方式实现相对位置编码”,可以说是具备了绝对位置编码的方便性,同时可以表示不同 token 之间的相对位置关系。[5] 不同于原始 Transformers 论文中,将 pos embedding 和 token embedding 进行相加,RoPE 是将位置编码和 query (或者 key) 进行相乘。具体如下:

Rotary Position Embedding

其中,左侧的矩阵 𝑅𝑚 表示位置第 𝑚 个位置的位置编码,右侧的向量 𝑞𝑖 表示对应位置的 query 向量。两者相乘,即可得到增加了位置信息的 query (或者 key)。由于 𝑅𝑚 的稀疏性,上述矩阵乘法可以等价于:

Rotary Position Embedding 的简化实现

其中 ⊗ 是逐位对应相乘,

RoPE的代码实现如下[6]:

# 代码增加了注释,可以看到和原始公式的对应关系。
class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()# 此处 inv_freq 对应公式中的 thetainv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 此处 freqs 对应公式中的 m * theta, t 对应公式中的 m,表示位置freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 此处和原始公式不同,theta_0 和 theta_0 不再相邻# 而是分在向量的前半部分和后半部分emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 大部分情况下,直接从这里返回return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)def rotate_half(x):"""Rotates half the hidden dims of the input."""# 此次和原始推导中不同,正负号不是间隔的,而是分前半部分和后半部分。但对于结果没有影响x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]# 对应上图中 RoPE 的简化计算q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed

四、高效实现

加速训练:

  • 使用了xformers库。
  • 减少了activation checkpointing 中,重新计算 activation 的计算量。手动实现 transformer 层的反向传递函数,保存了计算成本高的 activations,例如线性层的输出。
  • 通过使用 model parallelism 和 sequence parallelism 来减少显存的使用量。
  • 尽可能地将 activations 的计算和GPU之间的通讯进行并行。

加速效果:

  • 65B的模型,在2048个80G的A100 GPU上,可以达到380 tokens/sec/GPU的速度。训练1.4T tokens需要21天。

五、主要结果与结论

Massive Multitask LanguageUnderstanding

LLaMA-13B 优于 GPT-3,尽管只有1/10大小。 LLaMA-65B 是可以与 Chinchilla-70B 和 PaLM-540B 这种最佳的LLM相竞争的模型。经过微调之后,LLaMA的效果有显著的提升。

未来打算发布在更大的语料上预训练上的更大的模型,因为随着数据和模型的增大,可以看到 performance 的稳定提升。

优化器

LLaMA使用了AdamW优化器进行训练,优化器的超参数为 =0.9, =0.95

(关于AdamW这个大模型训练的优化器,可参考当前训练神经网络最快的方式:AdamW优化算法+超级收敛 | 机器之心[6])

下表为LLaMA不同参数大小模型的具体设置:

表2: LLaMA不同参数大小模型的具体设置

参数维度(dim)head个数layer层数学习率batch sizetoken数量
6.7B409632323.0e−44M1.0T
13.0B512040403.0e−44M1.0T
32.5B665652601.5e−44M1.4T
65.2B819264801.5e−44M1.4T

训练结果

如下图所示,7B、13B、33B和65模型的训练损失均呈下降趋势,且在所有token上训练完后,loss仍没有收敛的趋势。因此,在此时,增加训练的token数量,仍然可以使模型继续学习。

(LLaMA2就是在此结论的基础上,使用了更多的token进行训练)

020f808566e73586ea9239922bce9824.png

高效部署

研究团队做了一些优化来提高模型的训练速度:

  1. 因果多头注意的有效实现:使用因果多头注意的有效实现来减少内存使用和运行时间。该实现可在xformers库中获得,其灵感来自于固定激活值显存优化和FlashAttention。这是通过不存储注意力权重和不计算由于语言建模任务的因果性质而被掩盖的key/query分数来实现的。

  2. 激活重计算:为了进一步提高训练效率,通过检查点减少了在向后传递过程中重新计算的激活量。更准确地说,节省了计算成本高的激活,比如线性层的输出。这是通过手动实现transformer层的backward函数来实现的,而不是依赖于PyTorch的autograd。

  3. 模型并行和序列并行:为了从这种优化中充分受益,需要通过使用模型和序列并行来减少模型的内存使用。此外,还尽可能地重叠激活的计算和gpu之间通过网络的通信。

笔者NOTE:LLM的高效训练是LLM工程实现的基础,对于这部分,各位小伙伴还是需要深入地了解一下各种并行策略、因果多头注意的有效实现、 激活重计算、混合精度训练。

参考

  1. ^Training Compute-Optimal Large Language Models https://arxiv.org/abs/2203.15556
  2. ^Training Compute-Optimal Large Language Models https://arxiv.org/abs/2203.15556
  3. ^Root Mean Square Layer Normalization https://arxiv.org/pdf/1910.07467.pdf
  4. ^GLU Variants Improve Transformer https://arxiv.org/pdf/2002.05202.pdf
  5. ^Transformer升级之路:2、博采众长的旋转式位置编码 Transformer升级之路:2、博采众长的旋转式位置编码 - 科学空间|Scientific Spaces
  6. ^transformers/src/transformers/models/llama/modeling_llama.py at main · huggingface/transformers · GitHub

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/6558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解 Java 并发:AbstractQueuedSynchronizer 源码分析

序言 在多线程编程中,同步机制是保障线程安全和协调线程之间操作顺序的重要手段。AQS 作为 Java 中同步机制的基础框架,为开发者提供了一个灵活且高效的同步工具。本文将通过对 AQS 源码的分析,解读 AQS 的核心实现原理,并深入探…

使用FastGPT+OneAPI在本地使用Llama3

FastGPT 是一个基于 LLM 大语言模型的知识库问答系统,提供开箱即用的数据处理、模型调用等能力。同时可以通过 Flow 可视化进行工作流编排,从而实现复杂的问答场景!他的重要特点就是工作流编排。 工作流编排:基于 Flow 模块的工作…

微信小程序 uniapp家庭食谱菜谱食材网上商城系统小程序ko137

随着生活节奏的不断加快,越来越多的人因为工作忙而没有时间自己出去订购喜欢的菜品。随着Internet的飞速发展,网络已经成为我们日常生活中必不可少的部分,越来越多的人也接受了电子商务这种快捷、方便的交易方式。网上订餐其独有的便捷性和直…

口才训练:如何用声音和语言展现自我魅力

口才训练:如何用声音和语言展现自我魅力 这里有一篇1270字左右的文章,主要介绍如何用声音和语言来展现自我魅力: 口才训练是提升个人魅力的重要途径之一。魅力不仅取决于外表,更重要的是声音和语言的运用。良好的语言表达能力可以…

Spring扩展点(一)Bean生命周期扩展点

Bean生命周期扩展点 影响多个Bean的实例化InstantiationAwareBeanPostProcessorBeanPostProcessor 影响单个Bean的实例化纯粹的生命周期回调函数InitializingBean(BeanPostProcessor 的before和after之间调用)DisposableBean Aware接口在生命周期实例化过…

二叉树的实现(详解,数据结构)

目录 一,二叉树需要实现的功能 二,下面是各功能详解 0.思想: 1.创建二叉树结点: 2.通过前序遍历的数组"ABD##E#H##CF##G##"构建二叉树 3.二叉树销毁: 4.前序遍历: 5.中序遍历:…

RabbitMQ之顺序消费

什么是顺序消费 例如:业务上产生者发送三条消息, 分别是对同一条数据的增加、修改、删除操作, 如果没有保证顺序消费,执行顺序可能变成删除、修改、增加,这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…

3GPP官网下载协议步骤

1.打开官网 https://www.3gpp.org/ 2.点击 3.在界面选择要找的series,跳转到查找界面 以V2X通信协议为例,论文中通常会看到许多应用: [7] “Study on evaluation methodology of new Vehicle-to-Everything (V2X) use cases for LTE and NR…

3.2Java全栈开发前端+后端(全栈工程师进阶之路)-前端框架VUE3框架-企业级应用- Vuex

Vuex简介 Vuex概述 Vuex是一个专门为Vue.js应用程序开发的状态管理模式, 它采用集中式存储管理所有组件的公共状态, 并以相应的规 则保证状态以一种可预测的方式发生变化. 试想这样的场景, 比如一个Vue的根实例下面有一个根组件名为App.vue, 它下面有两个子组件A.vue和B.vu…

022、Python+fastapi,第一个Python项目走向第22步:ubuntu 24.04 docker 安装mysql8集群、redis集群(三)

这次来安装mysql8了,以前安装不是docker安装,这个我也是第一次,人人都有第一次嚒 前言 前面的redis安装还是花了点时间的,主要是网上教程,各有各的好,大家千万别取其长处,个人觉得这个环境影响…

ASP.NET网上车辆档案管理系统

摘 要 本文采用基于Web的Asp.net技术,并与sql server 2000数据库相结合,研发了一套车辆档案管理系统。该系统扩展性好,易于维护。简化了车辆档案设计流程,去除了冗余信息。汽车销售企业可以通过本系统完成整个销售及售后所有档案…

python爬虫实战

import requests import json yesinput(输入页数:) yesint(yes)headers {"accept": "application/json, text/plain, */*","accept-language": "zh-CN,zh;q0.9","content-type": "application/json",…

一对一WebRTC视频通话系列(三)——leave和peer-leave信令实现

本篇博客主要分为两部分,第一部分为leave信令的实现,即当有客户端离开房间后,服务端和其他在房间内的客户需知晓。第二部分为媒体协商和网络协商相关API。 本系列博客主要记录一对一WebRTC视频通话实现过程中的一些重点,代码全部进…

渗透之sql盲注(时间/boolean盲注)

sql盲注:sql盲注意思是我们并不能在web页面中看到具体的信息,我们只能通过输入的语句的真假来判断。从而拿到我们想要的信息。 我们通常使用ascii值来进行盲注。 目录 手动注入: 时间盲注: 布尔盲注: python脚本注…

【Java】基本程序设计结构(一)

前言:现在,假定已经成功安装了JDK,并且能够运行上篇示例程序。本篇将开始介绍Java程序中的基本设计结构,其中包括:一个简单的Java应用,注释,数据类型,变量与常量,运算符&…

【深度学习基础(3)】初识神经网络之深度学习hello world

文章目录 一. 训练Keras中的MNIST数据集二. 工作流程1. 构建神经网络2. 准备图像数据3. 训练模型4. 利用模型进行预测5. (新数据上)评估模型精度 本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了…

【架构系列】RabbitMQ应用场景及在实际项目中如何搭建可靠的RabbitMQ架构体系

作者:后端小肥肠 创作不易,未经允许禁止转载。 1. 前言 RabbitMQ,作为一款高性能、可靠的消息队列软件,已经成为许多企业和开发团队的首选之一。它的灵活性和可扩展性使得它适用于各种应用场景,从简单的任务队列到复杂的分布式系统…

算法设计与分析——期末1h

目录 第一章 算法的定义 算法的三要素 算法的基本性质 算法的时间复杂度数量级: 第二章 兔子繁殖问题(递推法) 猴子吃桃问题(递推法) 穿越沙漠问题(递推法(倒推)) 百钱百…

解决Maven本地仓库存在依赖包还需要远程下载的问题

背景 公司有自己maven私服,正在在私服可以使用的情况,打包是没问题的。但是这次是由于公司大楼整体因电路检修而停电,所有服务器关机,包括maven私服服务器。然后当天确有一个包需要打,这个时候发现死活打不了&#xf…

线性数据结构-手写链表-LinkList

为什么需要手写实现数据结构? 其实技术的本身就是基础的积累和搭建的过程,基础扎实 地基平稳 万丈高楼才会久战不衰,做技术能一通百,百通千就不怕有再难得技术了。 一:链表的分类 主要有单向,双向和循环链表…