LLaMA开源大模型源码分析!

 Datawhale干货 

作者:宋志学,Datawhale成员

花了一晚上照着transformers仓库的LLaMA源码,把张量并行和梯度保存的代码删掉,只留下模型基础结构,梳理了一遍LLaMA的模型结构。

今年四月份的时候,我第一次接触深度学习,也是今年第一次接触Datawhale,在Datawhale和小伙伴一起学习、讨论了大半年,不知不觉已经可以做到看源码的程度了。

Datawhale才是一个没有围墙的大学,在这里无论你有什么想法💡,只要你愿意前进,总会有小伙伴和你一起。

博客地址:

https://flowus.cn/kmno4/share/527055be-464f-4f0f-98c5-8b8f72a1fc2e

LLaMA-Model

在transformers仓库中可以看到llama的源码,首先是LlamaModel类,继承自PreTrainedModel,这个类是所有模型的基类,包含了一些通用的方法,比如保存模型、加载模型、初始化权重等。

继承关系为:LlamaModel-> LlamaPreTrainedModel-> PreTrainedModel

LlamaConfig

LlamaConfig 中主要是定义一些参数,比如vocab_size、hidden_size、num_hidden_layers、num_attention_heads等。所有的参数有默认值,可以直接创建cofing就能用。

config = LlamaConfig()

LlamaModel

6783830017ed6f9a29869202cd8218ff.jpeg

LlamaModel 初始化

  • 设置了模型的两个属性:padding_idx(用于指定填充标记的索引),vocab_size(词汇表的大小)

  • 初始化了模型的嵌入层、解码器层、归一化层

  • 嵌入层(nn.Embedding):模型使用嵌入层将输入的标记映射成密集的向量表示。

  • 解码器层(nn.ModuleList()):模型包含多个解码器层,这些层都是由 LlamaDecoderLayer 定义

  • 归一化层 LlamaRMSNorm:归一化层使用的是 Root Mean Square Layer Normalization(RMS Layer Norm)

  • 设置了是否使用 gradient_checkpoint 主要是用来节省显存

  • 调用 post_init() 完成一些初始化和准备检查的代码

def __init__(self, config: LlamaConfig):super().__init__(config)self.padding_idx = config.pad_token_idself.vocab_size = config.vocab_size# embedding 层self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)# 中间的一堆 decoderlayers 层self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_sdpa = config._attn_implementation == "sdpa"self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.gradient_checkpointing = False# Initialize weights and apply final processingself.post_init()

可以看一下 post_init() 的代码,主要是初始化权重和gradient_checkpointing相关的一些事情。该方法在PreTrainedModel基类中,transformers中所有模型基本都继承这个类。

def post_init(self):"""A method executed at the end of each Transformer model initialization, to execute code that needs the model'smodules properly initialized (such as weight initialization)."""self.init_weights()self._backward_compatibility_gradient_checkpointing()

LlamaModel forward

forward 部分的代码有点长,但其实大部分都是张量并行或者是节省显存相关的代码,对于理解模型结构来说可以直接忽略。

首先进来就是把 inputs_ids 进行向量化,然后拿到 hidden_states 。然后是存起来所有的hidden_states 进入 decoder_layer 再拿一个 hidden_states,作为下一轮 decoder_layerhidden_states 输入,最后给 hidden_states norm一下。如下代码所示:

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embedsfor decoder_layer in self.layers:# 存起来所有的 hidden_statesif output_hidden_states:all_hidden_states += (hidden_states,)# 这里是 decoder_layer 的 forwardlayer_outputs = decoder_layer(hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_values,output_attentions=output_attentions,use_cache=use_cache,)# 再拿一个 hidden_states,作为下一轮 decoder_layer 的 hidden_states 输入hidden_states = layer_outputs[0]hidden_states = self.norm(hidden_states)

最后就是以 BaseModelOutputWithPast 的形式输出。ok,接下来继续看decoder_layer中的其他代码。

LlamaDecoderLayer

Embedding层不用多说,用的就是torch中的nn.Embedding。那就直接来看DecoderLayer。

8e444bd665786abd4f4a471a8a72b244.png

DecoderLayers 初始化

先来看初始化。

  • hidden_size : 也就是在上面说的输入输出。

  • self_attn : 别看它写这么多啊,其实就是选一下用什么 attention 。看见大写字母不要怕,直接点进去看看怎么个事!

    LLAMA_ATTENTION_CLASSES = {"eager": LlamaAttention,"flash_attention_2": LlamaFlashAttention2,"sdpa": LlamaSdpaAttention,
    }
  • mlp : 一个全连接层 LlamaMLP 这个待会后面再说,输入输出都是 hidden_size 大小。

  • input_layernorm : LlamaRMSNorm 层,输入时候的norm

  • post_attention_layernorm : 丢入 mlp 之前的操作。

class LlamaDecoderLayer(nn.Module):def __init__(self, config: LlamaConfig, layer_idx: int):super().__init__()self.hidden_size = config.hidden_sizeself.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)self.mlp = LlamaMLP(config)self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\

DecoderLayers forward

首先复制一份 hidden_statesresidual。然后 hidden_states 进入 input_layernorm 进行norm。然后进入 self_attn 进行 attention 操作,拿到 hidden_statesself_attn_weightspresent_key_value。然后 hidden_statesresidual 相加,得到 hidden_states

然后 hidden_states 进入 post_attention_layernorm 进行norm。最后 hidden_states 进入 mlp 进行全连接操作,拿到 hidden_states。然后 hidden_statesresidual 相加,得到 hidden_states。最后输出 hidden_states

residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,**kwargs,
)
hidden_states = residual + hidden_states# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs

Llama Attention

31aaddd17134c8360db0117ba21d30b6.png

看代码首先映入眼帘的就是  Attention Is All You Need  好好好,很有精神!那我们接着往下看。

先来看 init 部分叭。

  • layer_idx : 这个就是第几个 DecoderLayers 层。不用关心。

  • attention_dropout : 用于dropout的概率。

  • hidden_size : 输入输出大小。

  • num_attention_heads : 多头注意力的头数。

  • head_dim : 多头注意力的维度 self.hidden_size // self.num_heads,和transformers中的一样。

  • num_key_value_heads : 用于key和value的头数。

其他的参数都在 LlamaConfig 中有默认值,可以直接使用,也可以直接去LlamaConfig的源码中看具体的解释,这里就不再多说。

再往下就是 q_projk_projv_projo_proj 四个矩阵(全连接层),耳熟能详了。

class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxif layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ""to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.attention_dropout = config.attention_dropoutself.hidden_size = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.hidden_size // self.num_headsself.num_key_value_heads = config.num_key_value_headsself.num_key_value_groups = self.num_heads // self.num_key_value_headsself.max_position_embeddings = config.max_position_embeddingsself.rope_theta = config.rope_thetaself.is_causal = Trueif (self.head_dim * self.num_heads) != self.hidden_size:raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {self.num_heads}).")self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)self._init_rope()

LlamaAttention forward

重头戏来了,attention forward 部分。

注意:其中有关于张量并行或者显存节省的部分我就直接省略了,直接看主要代码。这个笔记主要是分析llama的模型结构,并不讨论如何节省显存。

首先拿到 hidden_statesbatch_sizeseq_len 。然后把 hidden_states 丢入 q_projk_projv_proj 三个矩阵(全连接层),拿到 query_stateskey_statesvalue_states 。然后把 query_stateskey_statesvalue_states reshape 为下一步计算做准备。

将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果

key_statesvalue_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化

然后 attn_weights 加上 attention_mask,再 softmaxdropout。然后 attn_weightsvalue_states 相乘,把 attn_output reshape 为下一步计算做准备,最后把 attn_output 丢入 o_proj ,然后return就行了。

好了,至此。我觉得llama最激动人心的地方已经结束了。

# 获取 batch_size 和 seq_len
bsz, q_len, _ = hidden_states.size()# 把 hidden_states 丢入 q_proj、k_proj、v_proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)# 把 q_proj、k_proj、v_proj 的输出 reshape 为下一步计算做准备
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# 将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)# 首先,它将key_states和value_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# 然后 attn_weights 加上 attention_mask
attn_weights = attn_weights + attention_mask# softmax + dropout
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 然后 attn_weights 和 value_states 相乘
attn_output = torch.matmul(attn_weights, value_states)# 然后把 attn_output reshape 为下一步计算做准备
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)# 最后把 attn_output 丢入 o_proj
attn_output = self.o_proj(attn_output)# 返回 attn_output、attn_weights、present_key_value
return attn_output, attn_weights, past_key_value

LlamaMLP

c1cd4cf6f3c0e2a88c2f2b536e2fd10f.png

看完 attention 再看 MLP ,突然就觉得好简单了,哈哈哈。这部分代码比较少,就直接放到一起了。

x进来之后先进去up_proj和gate_proj,gate_proj进行激活,然后这俩再乘起来,丢进 down_proj。那直接放个图叭,这个过程有点简单了。

class LlamaMLP(nn.Module):def __init__(self, config):super().__init__()# 这俩不必多说self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_size# 三个全连接层self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj

LlamaRMSNorm

RMSNorm函数可以用以下数学公式表示:

其中:

  • 是层的输入。

  • 代表层的权重。

  • 是权重的数量。

  • 是一个小常数,用于数值稳定性(以避免除以零的情况)。

这种归一化有助于通过确保权重的规模不会变得过大或过小来稳定学习过程,这在具有许多层的深度学习模型中特别有用。

class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)

参考:https://space.bilibili.com/45156039

8c2428ec014740c4d5aa5ee991a3cb14.png

干货学习,三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenAI 疑似正在进行 GPT-4.5 灰度测试!

‍ 大家好,我是二狗。 今天,有网友爆料OpenAI疑似正在进行GPT-4.5灰度测试! 当网友询问ChatGPT API调用查询模型的确切名称是什么时? ChatGPT的回答竟然是 gpt-4.5-turbo。 也有网友测试之后发现仍然是GPT-4模型。 这是有网友指…

自动化测试架构设计必会知识点——对核心业务进行封装复用(附Java源码)

随着UI自动化测试工具可选性越来越多,工具也越来越稳定,前几年关于自动化测试架构设计的概念逐渐淡化,但是做自动化测试最重要的两点—— PO设计模式和核心业务的封装复用大家还是必须掌握的,前面的文章我已经介绍了什么是PO设计模…

基于 Sentry 的前端监控系统搭建(Linux)

一、前言 随着技术这几年的发展与沉淀,线上数据指标监控也变得尤为重要,研发人员和运营人员需要对线上的产品指标有所感知,同时风险也需要及时暴露,很多公司开始自建监控系统,但对于一些定制化要求不是特别高的团队&a…

网络爬虫之Ajax动态数据采集

动态数据采集 规则 有时候我们在用 requests 抓取页面的时候,得到的结果可能和在浏览器中看到的不一样,在浏览器中可以看到正常显示的页面教据,但是使用 requests 得到的结果并没有,这是因为requests 获取的都是原始的 HTML 文档…

(1)(1.9) MSP (version 4.2)

文章目录 前言 1 协议概述 2 配置 3 参数说明 前言 ArduPilot 支持 MSP 协议,可通过任何串行端口进行遥测、OSD 和传感器。这样,ArduPilot 就能将遥测数据发送到 MSP 兼容设备(如大疆护目镜),用于屏幕显示&#x…

银河麒麟v10 安装mysql 8.35

银河麒麟v10 安装mysql 8.35 1、卸载mariadb2、下载Mysql安装包3、安装Mysql 8.353.1、安装依赖包3.2、安装Mysql3.3、安装后配置 1、卸载mariadb 由于银河麒麟v10系统默认安装了mariadb 会与Mysql相冲突,因此首先需要卸载系统自带的mariadb 查看系统上默认安装的M…

MyBatis动态SQL中if,where,set,trim四种标签的使用和联系

目录 MyBatis动态SQL中if,where,set,trim四种标签的使用和联系1、先介绍trim标签以下是trim标签中涉及到的属性: 2、使用trim标签或where标签去除多余的and关键字3、使用trim标签或set标签去除多余的逗号 MyBatis动态SQL中if&…

前端常用的开发工具

前端常用的开发工具🔖 文章目录 前端常用的开发工具🔖1. Snipaste--截图工具2. ScreenToGif--gif图片录制3. Typora--Markdown编辑器4. notepad--文本代码编辑器5. uTools--多功能工具6. EV录屏--录屏软件7. Xmind--思维导图8. Apifox -- 接口调试9. Tor…

【大数据】NiFi 中的 Controller Service

NiFi 中的 Controller Service 1.Service 简介1.1 Controller Service 的配置1.1.1 SETTING 基础属性1.1.2 PROPERTIES 使用属性1.1.3 COMMENT 页签 1.2 Service 的使用范围 2.全局参数配置3.DBCPConnectionPool 的使用样例4.在 ExcuseGroovyScript 组件中使用 Service 1.Servi…

记一次 Nginx 调参的踩坑经历

最近在基于SSE(Server Sent Events)做服务端单向推送服务,本地开发时一切顺利,但是在部署到预发环境时就碰到1个很诡异的问题,这里需要简单介绍下我们的整体架构: 整体架构 可以看到所有的请求都会先到统一…

2024 年 22 款顶级免费数据恢复软件比较 [Windows 和 Mac]

适用于 Windows 和 Mac 用户的最佳数据恢复软件下载列表和比较,可快速恢复丢失的数据、已删除的文件、照片或格式化的分区数据: 数据恢复软件是一种从任何存储介质恢复丢失文件的应用程序。它可以恢复由于病毒攻击、硬盘故障或任何其他原因而意外删除或…

NIO的实战教程(简单且高效)

1. 参考 建议按顺序阅读以下三篇文章 为什么NIO被称为同步非阻塞? Java IO 与 NIO:高效的输入输出操作探究 【Java.NIO】Selector,及SelectionKey 2. 实战 我们将模拟一个简单的HTTP服务器,它将响应客户端请求并返回一个固定的…

Maven核心概念

1 Maven工程的GAVP Maven 中的 GAVP 是指 GroupId、ArtifactId、Version、Packaging 等四个属性的缩写,其中前三个是必要的,而 Packaging 属性为可选项。 这四个属性主要为每个项目在maven仓库中做一个标识,方便项目之间相互引用。 GAV G 即…

桶装水送水小程序:提升服务质量的利器

随着移动互联网的发展,越来越多的消费者通过手机在线购物和订购商品。如果你是一名桶装水供应商,想要拓展线上业务,那么开发一个桶装水微信小程序将是一个明智的选择。本文将指导你从零开始开发一个桶装水微信小程序,让你轻松完成…

Coze在手,GPTsDALLE免费用

1. 关于Coze Coze 是一个应用程序编辑平台,旨在开发下一代人工智能聊天机器人。 你可以使用无代码创建各种类型的聊天机器人,并将其部署到各种社交平台和消息应用程序。 链接: Coze 2. Coze的特点 Coze有5个特点。下面由我来详细介绍一下!…

高级数据结构 <二叉搜索树>

本文已收录至《数据结构(C/C语言)》专栏! 作者:ARMCSKGT 目录 前言正文二叉搜索树的概念二叉搜索树的基本功能实现二叉搜索树的基本框架插入节点删除节点查找函数中序遍历函数析构函数和销毁函数(后序遍历销毁)拷贝构造和赋值重载(前序遍历创建)其他函数…

蓝牙物联网与嵌入式开发如何结合?

蓝牙物联网与嵌入式开发可以紧密结合,以实现更高效、更智能的物联网应用。以下是一些结合的方式: 嵌入式开发为蓝牙设备提供硬件基础设施和控制逻辑:嵌入式系统可以利用微处理器和各种外设组成的系统,为蓝牙设备提供硬件基础设施和…

基于ERC20代币协议实现的去中心化应用平台

文章目录 内容简介设计逻辑ERC20TokenLoanPlatform 合约事件结构体状态变量函数 Remix 运行实现部署相关智能合约存款和取款贷款和还款 源码地址 内容简介 使用 solidity 实现的基于 ERC20 代币协议的借贷款去中心化应用平台(极简版)。实现存款、取款、贷款、还款以及利息计算的…

爬虫API|批量抓取电商平台商品数据,支持高并发

随着互联网的快速发展,电商平台如雨后春笋般涌现,为消费者提供了丰富的购物选择。然而,对于许多商家和数据分析师来说,如何快速、准确地获取电商平台上的商品数据成为了一个难题。为了解决这个问题,我们开发了一个爬虫…

ModuleNotFoundError: No module named ‘tensorflow‘

直接运行pip install tensorflow安装成功之后,发现版本是tensorflow2.15.0 python的版本是3.9版本 导入包:import tensorflow 打包xxx.exe,调用之后提示错误 ModuleNotFoundError: No module named tensorflow 最后发现特定的python的版本对应特定的t…