LLaMA开源大模型源码分析!

 Datawhale干货 

作者:宋志学,Datawhale成员

花了一晚上照着transformers仓库的LLaMA源码,把张量并行和梯度保存的代码删掉,只留下模型基础结构,梳理了一遍LLaMA的模型结构。

今年四月份的时候,我第一次接触深度学习,也是今年第一次接触Datawhale,在Datawhale和小伙伴一起学习、讨论了大半年,不知不觉已经可以做到看源码的程度了。

Datawhale才是一个没有围墙的大学,在这里无论你有什么想法💡,只要你愿意前进,总会有小伙伴和你一起。

博客地址:

https://flowus.cn/kmno4/share/527055be-464f-4f0f-98c5-8b8f72a1fc2e

LLaMA-Model

在transformers仓库中可以看到llama的源码,首先是LlamaModel类,继承自PreTrainedModel,这个类是所有模型的基类,包含了一些通用的方法,比如保存模型、加载模型、初始化权重等。

继承关系为:LlamaModel-> LlamaPreTrainedModel-> PreTrainedModel

LlamaConfig

LlamaConfig 中主要是定义一些参数,比如vocab_size、hidden_size、num_hidden_layers、num_attention_heads等。所有的参数有默认值,可以直接创建cofing就能用。

config = LlamaConfig()

LlamaModel

6783830017ed6f9a29869202cd8218ff.jpeg

LlamaModel 初始化

  • 设置了模型的两个属性:padding_idx(用于指定填充标记的索引),vocab_size(词汇表的大小)

  • 初始化了模型的嵌入层、解码器层、归一化层

  • 嵌入层(nn.Embedding):模型使用嵌入层将输入的标记映射成密集的向量表示。

  • 解码器层(nn.ModuleList()):模型包含多个解码器层,这些层都是由 LlamaDecoderLayer 定义

  • 归一化层 LlamaRMSNorm:归一化层使用的是 Root Mean Square Layer Normalization(RMS Layer Norm)

  • 设置了是否使用 gradient_checkpoint 主要是用来节省显存

  • 调用 post_init() 完成一些初始化和准备检查的代码

def __init__(self, config: LlamaConfig):super().__init__(config)self.padding_idx = config.pad_token_idself.vocab_size = config.vocab_size# embedding 层self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)# 中间的一堆 decoderlayers 层self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_sdpa = config._attn_implementation == "sdpa"self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.gradient_checkpointing = False# Initialize weights and apply final processingself.post_init()

可以看一下 post_init() 的代码,主要是初始化权重和gradient_checkpointing相关的一些事情。该方法在PreTrainedModel基类中,transformers中所有模型基本都继承这个类。

def post_init(self):"""A method executed at the end of each Transformer model initialization, to execute code that needs the model'smodules properly initialized (such as weight initialization)."""self.init_weights()self._backward_compatibility_gradient_checkpointing()

LlamaModel forward

forward 部分的代码有点长,但其实大部分都是张量并行或者是节省显存相关的代码,对于理解模型结构来说可以直接忽略。

首先进来就是把 inputs_ids 进行向量化,然后拿到 hidden_states 。然后是存起来所有的hidden_states 进入 decoder_layer 再拿一个 hidden_states,作为下一轮 decoder_layerhidden_states 输入,最后给 hidden_states norm一下。如下代码所示:

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embedsfor decoder_layer in self.layers:# 存起来所有的 hidden_statesif output_hidden_states:all_hidden_states += (hidden_states,)# 这里是 decoder_layer 的 forwardlayer_outputs = decoder_layer(hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_values,output_attentions=output_attentions,use_cache=use_cache,)# 再拿一个 hidden_states,作为下一轮 decoder_layer 的 hidden_states 输入hidden_states = layer_outputs[0]hidden_states = self.norm(hidden_states)

最后就是以 BaseModelOutputWithPast 的形式输出。ok,接下来继续看decoder_layer中的其他代码。

LlamaDecoderLayer

Embedding层不用多说,用的就是torch中的nn.Embedding。那就直接来看DecoderLayer。

8e444bd665786abd4f4a471a8a72b244.png

DecoderLayers 初始化

先来看初始化。

  • hidden_size : 也就是在上面说的输入输出。

  • self_attn : 别看它写这么多啊,其实就是选一下用什么 attention 。看见大写字母不要怕,直接点进去看看怎么个事!

    LLAMA_ATTENTION_CLASSES = {"eager": LlamaAttention,"flash_attention_2": LlamaFlashAttention2,"sdpa": LlamaSdpaAttention,
    }
  • mlp : 一个全连接层 LlamaMLP 这个待会后面再说,输入输出都是 hidden_size 大小。

  • input_layernorm : LlamaRMSNorm 层,输入时候的norm

  • post_attention_layernorm : 丢入 mlp 之前的操作。

class LlamaDecoderLayer(nn.Module):def __init__(self, config: LlamaConfig, layer_idx: int):super().__init__()self.hidden_size = config.hidden_sizeself.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)self.mlp = LlamaMLP(config)self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\

DecoderLayers forward

首先复制一份 hidden_statesresidual。然后 hidden_states 进入 input_layernorm 进行norm。然后进入 self_attn 进行 attention 操作,拿到 hidden_statesself_attn_weightspresent_key_value。然后 hidden_statesresidual 相加,得到 hidden_states

然后 hidden_states 进入 post_attention_layernorm 进行norm。最后 hidden_states 进入 mlp 进行全连接操作,拿到 hidden_states。然后 hidden_statesresidual 相加,得到 hidden_states。最后输出 hidden_states

residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,**kwargs,
)
hidden_states = residual + hidden_states# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs

Llama Attention

31aaddd17134c8360db0117ba21d30b6.png

看代码首先映入眼帘的就是  Attention Is All You Need  好好好,很有精神!那我们接着往下看。

先来看 init 部分叭。

  • layer_idx : 这个就是第几个 DecoderLayers 层。不用关心。

  • attention_dropout : 用于dropout的概率。

  • hidden_size : 输入输出大小。

  • num_attention_heads : 多头注意力的头数。

  • head_dim : 多头注意力的维度 self.hidden_size // self.num_heads,和transformers中的一样。

  • num_key_value_heads : 用于key和value的头数。

其他的参数都在 LlamaConfig 中有默认值,可以直接使用,也可以直接去LlamaConfig的源码中看具体的解释,这里就不再多说。

再往下就是 q_projk_projv_projo_proj 四个矩阵(全连接层),耳熟能详了。

class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxif layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ""to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.attention_dropout = config.attention_dropoutself.hidden_size = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.hidden_size // self.num_headsself.num_key_value_heads = config.num_key_value_headsself.num_key_value_groups = self.num_heads // self.num_key_value_headsself.max_position_embeddings = config.max_position_embeddingsself.rope_theta = config.rope_thetaself.is_causal = Trueif (self.head_dim * self.num_heads) != self.hidden_size:raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {self.num_heads}).")self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)self._init_rope()

LlamaAttention forward

重头戏来了,attention forward 部分。

注意:其中有关于张量并行或者显存节省的部分我就直接省略了,直接看主要代码。这个笔记主要是分析llama的模型结构,并不讨论如何节省显存。

首先拿到 hidden_statesbatch_sizeseq_len 。然后把 hidden_states 丢入 q_projk_projv_proj 三个矩阵(全连接层),拿到 query_stateskey_statesvalue_states 。然后把 query_stateskey_statesvalue_states reshape 为下一步计算做准备。

将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果

key_statesvalue_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化

然后 attn_weights 加上 attention_mask,再 softmaxdropout。然后 attn_weightsvalue_states 相乘,把 attn_output reshape 为下一步计算做准备,最后把 attn_output 丢入 o_proj ,然后return就行了。

好了,至此。我觉得llama最激动人心的地方已经结束了。

# 获取 batch_size 和 seq_len
bsz, q_len, _ = hidden_states.size()# 把 hidden_states 丢入 q_proj、k_proj、v_proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)# 把 q_proj、k_proj、v_proj 的输出 reshape 为下一步计算做准备
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# 将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)# 首先,它将key_states和value_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# 然后 attn_weights 加上 attention_mask
attn_weights = attn_weights + attention_mask# softmax + dropout
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 然后 attn_weights 和 value_states 相乘
attn_output = torch.matmul(attn_weights, value_states)# 然后把 attn_output reshape 为下一步计算做准备
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)# 最后把 attn_output 丢入 o_proj
attn_output = self.o_proj(attn_output)# 返回 attn_output、attn_weights、present_key_value
return attn_output, attn_weights, past_key_value

LlamaMLP

c1cd4cf6f3c0e2a88c2f2b536e2fd10f.png

看完 attention 再看 MLP ,突然就觉得好简单了,哈哈哈。这部分代码比较少,就直接放到一起了。

x进来之后先进去up_proj和gate_proj,gate_proj进行激活,然后这俩再乘起来,丢进 down_proj。那直接放个图叭,这个过程有点简单了。

class LlamaMLP(nn.Module):def __init__(self, config):super().__init__()# 这俩不必多说self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_size# 三个全连接层self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj

LlamaRMSNorm

RMSNorm函数可以用以下数学公式表示:

其中:

  • 是层的输入。

  • 代表层的权重。

  • 是权重的数量。

  • 是一个小常数,用于数值稳定性(以避免除以零的情况)。

这种归一化有助于通过确保权重的规模不会变得过大或过小来稳定学习过程,这在具有许多层的深度学习模型中特别有用。

class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)

参考:https://space.bilibili.com/45156039

8c2428ec014740c4d5aa5ee991a3cb14.png

干货学习,三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenAI 疑似正在进行 GPT-4.5 灰度测试!

‍ 大家好,我是二狗。 今天,有网友爆料OpenAI疑似正在进行GPT-4.5灰度测试! 当网友询问ChatGPT API调用查询模型的确切名称是什么时? ChatGPT的回答竟然是 gpt-4.5-turbo。 也有网友测试之后发现仍然是GPT-4模型。 这是有网友指…

自动化测试架构设计必会知识点——对核心业务进行封装复用(附Java源码)

随着UI自动化测试工具可选性越来越多,工具也越来越稳定,前几年关于自动化测试架构设计的概念逐渐淡化,但是做自动化测试最重要的两点—— PO设计模式和核心业务的封装复用大家还是必须掌握的,前面的文章我已经介绍了什么是PO设计模…

交叉熵损失(Cross Entropy Loss)学习笔记

在分类任务中,我们通常使用交叉熵作为损失函数,首先给出交叉熵的计算公式: 二分类中: L 1 N ∑ i L i 1 N ∑ i − [ y i l o g ( p i ) ( 1 − y i ) ⋅ l o g ( 1 − p i ) ] \mathcal{L}\frac1{N}\sum_{i}L_i\frac1{N}\sum…

基于 Sentry 的前端监控系统搭建(Linux)

一、前言 随着技术这几年的发展与沉淀,线上数据指标监控也变得尤为重要,研发人员和运营人员需要对线上的产品指标有所感知,同时风险也需要及时暴露,很多公司开始自建监控系统,但对于一些定制化要求不是特别高的团队&a…

【PHP】Yii2 使用validate规则验证

目录 提交验证 声明规则 特殊验证 一个特殊验证的示例 内联验证器 一个完整示例 参考文档 提交验证 根据经验,您永远不应该相信从最终用户那里收到的数据,并且应该在很好地使用这些数据之前对其进行验证。 给定一个model模型,用户输入填…

网络爬虫之Ajax动态数据采集

动态数据采集 规则 有时候我们在用 requests 抓取页面的时候,得到的结果可能和在浏览器中看到的不一样,在浏览器中可以看到正常显示的页面教据,但是使用 requests 得到的结果并没有,这是因为requests 获取的都是原始的 HTML 文档…

前端跨页面交互信息或传递信息都有这么几种方式,总有一个满足你的使用场景

同源页面通信&#xff1a; 方法&#xff1a;Broadcast Channel 【广播】 这个可以查看我另一篇文章有使用案例 同一来源的不同文档&#xff08;在不同的窗口、选项卡、框架或 iframe 中&#xff09;之间进行通信 // page1<!DOCTYPE html> <html lang"en"…

(1)(1.9) MSP (version 4.2)

文章目录 前言 1 协议概述 2 配置 3 参数说明 前言 ArduPilot 支持 MSP 协议&#xff0c;可通过任何串行端口进行遥测、OSD 和传感器。这样&#xff0c;ArduPilot 就能将遥测数据发送到 MSP 兼容设备&#xff08;如大疆护目镜&#xff09;&#xff0c;用于屏幕显示&#x…

(C)一些题15

1.下列关于 C 语言程序结构的说法中&#xff0c;不正确的是&#xff08;D) A &#xff0e;一个程序由一个或多个源程序文件组成 B &#xff0e;函数是 C 程序的主要组成部分 C &#xff0e;程序总是从 main 函数开始执行的 D . C 语言本身提供了许多输入输出语句 解析&…

【Java基础】01- 基础概念

注释和关键字 注释 定义&#xff1a;代码中需要一些解释说明性的文字&#xff0c;通常称为注释。 单行注释&#xff1a; // 注释信息多行注释&#xff1a;/* 注释信息 */文档注释&#xff1a;/** 注释信息 */ 文档注释可以利用Java中自带的DOC工具&#xff0c;自动生成相关代…

银河麒麟v10 安装mysql 8.35

银河麒麟v10 安装mysql 8.35 1、卸载mariadb2、下载Mysql安装包3、安装Mysql 8.353.1、安装依赖包3.2、安装Mysql3.3、安装后配置 1、卸载mariadb 由于银河麒麟v10系统默认安装了mariadb 会与Mysql相冲突&#xff0c;因此首先需要卸载系统自带的mariadb 查看系统上默认安装的M…

MyBatis动态SQL中if,where,set,trim四种标签的使用和联系

目录 MyBatis动态SQL中if&#xff0c;where&#xff0c;set&#xff0c;trim四种标签的使用和联系1、先介绍trim标签以下是trim标签中涉及到的属性&#xff1a; 2、使用trim标签或where标签去除多余的and关键字3、使用trim标签或set标签去除多余的逗号 MyBatis动态SQL中if&…

前端常用的开发工具

前端常用的开发工具&#x1f516; 文章目录 前端常用的开发工具&#x1f516;1. Snipaste--截图工具2. ScreenToGif--gif图片录制3. Typora--Markdown编辑器4. notepad--文本代码编辑器5. uTools--多功能工具6. EV录屏--录屏软件7. Xmind--思维导图8. Apifox -- 接口调试9. Tor…

【大数据】NiFi 中的 Controller Service

NiFi 中的 Controller Service 1.Service 简介1.1 Controller Service 的配置1.1.1 SETTING 基础属性1.1.2 PROPERTIES 使用属性1.1.3 COMMENT 页签 1.2 Service 的使用范围 2.全局参数配置3.DBCPConnectionPool 的使用样例4.在 ExcuseGroovyScript 组件中使用 Service 1.Servi…

记一次 Nginx 调参的踩坑经历

最近在基于SSE&#xff08;Server Sent Events&#xff09;做服务端单向推送服务&#xff0c;本地开发时一切顺利&#xff0c;但是在部署到预发环境时就碰到1个很诡异的问题&#xff0c;这里需要简单介绍下我们的整体架构&#xff1a; 整体架构 可以看到所有的请求都会先到统一…

2024 年 22 款顶级免费数据恢复软件比较 [Windows 和 Mac]

适用于 Windows 和 Mac 用户的最佳数据恢复软件下载列表和比较&#xff0c;可快速恢复丢失的数据、已删除的文件、照片或格式化的分区数据&#xff1a; 数据恢复软件是一种从任何存储介质恢复丢失文件的应用程序。它可以恢复由于病毒攻击、硬盘故障或任何其他原因而意外删除或…

NIO的实战教程(简单且高效)

1. 参考 建议按顺序阅读以下三篇文章 为什么NIO被称为同步非阻塞&#xff1f; Java IO 与 NIO&#xff1a;高效的输入输出操作探究 【Java.NIO】Selector&#xff0c;及SelectionKey 2. 实战 我们将模拟一个简单的HTTP服务器&#xff0c;它将响应客户端请求并返回一个固定的…

Wails中js调用go函数(1种go写法,2种js调用方法)

官方js调用go方法文档&#xff1a;https://wails.io/zh-Hans/docs/howdoesitwork a&#xff09;在app.go文件里写一个要js调用的go函数&#xff1a; func (a *App) JSCallGo(data1 string) string { return “test” } b&#xff09;运行 wails dev 命令&#xff0c…

Maven核心概念

1 Maven工程的GAVP Maven 中的 GAVP 是指 GroupId、ArtifactId、Version、Packaging 等四个属性的缩写&#xff0c;其中前三个是必要的&#xff0c;而 Packaging 属性为可选项。 这四个属性主要为每个项目在maven仓库中做一个标识&#xff0c;方便项目之间相互引用。 GAV G 即…

54.0/CSS 样式属性(详细版)

目录 54.1 文本属性 54.1.1 设置字体——font-family 54.1.2 设置字号(font-size)和字的颜色(color) 54.1.3 设置字体加粗——font-weight 54.1.4 添加文字修饰——text-decoration 54.1.5 设置文本排列方式——text-align 54.1.6 设置段落首行缩进——text-indent …