Llama网络结构介绍

LLaMA现在已经是开源社区里炙手可热的模型了,但是原文中仅仅介绍了其和标准Transformer的差别,并没有一个全局的模型介绍。因此打算写篇文章,争取让读者不参考任何其他资料把LLaMA的模型搞懂。

结构

如图所示为LLaMA的示意图,由Attention和MLP层堆叠而成
在这里插入图片描述
LLaMA模型主要由Attention和MLP层堆叠而成,具有以下特点:
1、前置的RMSNorm:RMSNorm是一种归一化技术,用于稳定模型的训练过程,提高模型的收敛速度。
2、Q、K上的RoPE旋转式位置编码:位置编码用于捕捉序列中的位置信息,RoPE旋转式位置编码能够有效地处理长序列,提高模型的性能。
3、Causal mask:该机制保证每个位置只能看到前面的tokens,确保了模型的自回归性质。
4、使用了Group Query Attention:通过使用分组查询注意力(GQA),LLaMA能够在保持性能的同时,降低模型的计算复杂度,提高推理速度。
5、MLP表达式:down(up(x) * SILU(gate(x))),其中down, up, gate都是线性层
LLaMA各个不同大小的结构设置如下表所示。其中最大的65B的LLaMA用了2048张80GB的A100,batch size为4百万,训练一次需要21天。

Group Query Attention(V2 only)

自回归模型生成回答时,需要前面生成的KV缓存起来,来加速计算。多头注意力机制(MHA)需要的缓存量很大,Multi-Query Attention指出多个头之间可以共享KV对。Group Query Attention没有像MQA一样极端,将query分组,组内共享KV,效果接近MHA,速度上与MQA可比较。p.s. 这个技术falcon已经用上了,当时falcon说自己用的是multi query attention,因为当group=1时,GQA和MQA是等价的。falcon支持设置不同的G。
在这里插入图片描述

RMSNorm

这是在BERT、GPT等模型中广泛使用的LayerNorm:
在这里插入图片描述
RMSNorm(root mean square)发现LayerNorm的中心偏移没什么用(减去均值等操作)。将其去掉之后,效果几乎不变,但是速度提升了40%。最终公式为:
在这里插入图片描述
注意除了没有减均值,加偏置以外,分母上求的RMS而不是方差。

LLaMA在 Attention Layer和MLP的输入上使用了RMSNorm,相比在输出上使用,训练会更加稳定。

SwiGLU

LLaMA没有使用ReLU,而是使用了SwiGLU,有时也被称为SiLU。公式为:
,效果类似平滑版的ReLU:
在这里插入图片描述

RoPE

LLaMA使用了Rotary Position Embedding。对于Q的第m个位置向量q,通过以下方法注入位置编码:
在这里插入图片描述

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000):super().__init__()theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))t = torch.arange(max_position_mbeddings)freqs = torch.einsum("i,j->ij", t, theta)emb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos())self.register_buffer("sin_cached", emb.sin())def forward(self, seq_len=None):return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]# 在LlamaAttention通过以下命令调用:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)

以下代码将q沿着最后一个维度劈成两半,将后一半乘-1,然后连接在第一半之前,就得到了上式第三项。

# 在接下来的apply_rotary_pos_emb函数里调用def rotate_half(x):x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)

最后通过以下代码得到结合了位置编码的Q,K(K和Q使用同样的方式进行位置编码)。

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):q_embed = (q * cos[position_ids]) + (rotate_half(q) * sin[position_ids])k_embed = (k * cos[position_ids]) + (rotate_half(k) * sin[position_ids])return q_embed, k_embed# 在LlamaAttention中通过以下命令调用:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

绝对位置编码的优点是计算速度快等,缺点是拓展长度比较麻烦,且绝对位置并没有什么实际意义。而相对位置编码对学习token之间的关系很有意义,比如距离的很远的两个token之间的关联大概率很小,使用相对位置编码往往能够获得更好的效果。此外拓展长度也更容易,因为不论context size多长,只需关注最长距离以内的输入即可。相对位置编码的缺点是没有绝对位置编码计算速度快。

当我们计算Attention时,RoPE可以变成相对位置编码。
在这里插入图片描述
从上面这个公式可以看出,q和k的attention依赖相对距离m-n。因此RoPE为q、k注入的绝对位置编码,计算得到的attention,却变成了相对位置编码。妙的很,我这里为了不参考其他文章就很容易搞懂LLaMA的结构,简化了很多东西,推荐大家看一看RoPE原作者苏剑林的博客了解更多信息。

本文只关注LLaMA缺失的模型结构方面的介绍,对于文章的翻译可以参考其他的文章,
例如:靳伟,LLaMA大模型是如何炼成的,
其他参考文章:https://zhuanlan.zhihu.com/p/636784644
原文:https://arxiv.org/pdf/2302.13971.pdf。
文中参考的代码是huggingface的transformers库实现的版本,并不是Meta官方的代码。
备注说明:受笔者水平限制,如果哪里讲的不对,或者不够清晰易懂,欢迎在评论区与我交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/827718.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

域名被污染了只能换域名吗?

域名污染是指域名的解析结果受到恶意干扰或篡改,使得用户在访问相关网站时出现异常。很多域名遭遇过污染的情况,但是并不知道是域名污染,具体来说,域名污染可能表现为以下情况:用户无法通过输入正确的域名访问到目标网…

Windos环境下配置免费SSL证书详细步骤

获取免费证书 配置本机模拟域名 打开如下目录,hosts文件 C:\Windows\System32\drivers\etc 添加如下配置并保存 127.0.0.1 im.test.com下载安装 OpenSSL 下载链接 进入bin目录, 打开cmd窗口 执行如下命令,生成RSA私钥 ## 使用des3…

大型集团企业 怎么实现多区域文件交换?

很多大型集团企业,都会在全国各地,甚至海外,都设立分支机构,还有银行、邮政这类机构,都会在全国各地设立多个支行和网点,所以在日常经营过程中,都会存在多区域文件交换的场景。 大型集团企业在进…

JVM垃圾收集器--分区收集器

G1收集器 G1(Garbage-First Garbage Collector)在 JDK 1.7 时引入,在 JDK 9 时取代 CMS 成为了默认的垃圾收集器。G1 有五个属性:分代、增量、并行、标记整理、STW。 分代 G1收集器 将内部分为多个大小相等的区域,另…

Unity Shader 图形学【笔记一】

游戏图形学 源自:计算机图形学 涵盖:图形、动画的创建渲染展示 目标:性能优化、提高视觉质量,增强用户体验 技术:三维模型、纹理、光照、阴影、特效、动画、物理模拟、碰撞检测等 Unity Shader 是:un…

基于Vue+ElementPlus自定义带历史记录的搜索框组件

前言 基于Vue2.5ElementPlus实现的一个自定义带历史记录的搜索框组件 效果如图: 基本样式: 获取焦点后: 这里的历史记录默认最大存储10条,同时右侧的清空按钮可以清空所有历史记录。 同时搜索记录也支持点击搜索,按…

371D - Vessels

思路&#xff1a;用并查集维护&#xff0c;如果当前容器没有满&#xff0c;就指向自己&#xff0c;否则指向下一个容器。 这样就可以快速 find 到下一个没有满的容器&#xff0c;从而模拟询问 1。 代码&#xff1a; void solve(){int n;cin >> n;vector<int>p(n …

leetcode:滑动窗口----3. 无重复字符的最长子串

给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长 子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc"&#xff0c;所以其长度为 3。示例 2: 输入: s "bbbbb" 输出: 1 解释: 因为…

算法竞赛相关问题总结记录

前言 日常在校生或者是工作之余的同学或多或少都会参加一些竞赛,参加竞赛一方面可以锻炼自己的理解与实践能力&#xff0c;也能够增加自己的生活费&#xff0c;竞赛中的一些方案也可以后续作为自己论文的base,甚至是横向课题的框架。在算法竞赛中算法的差别个人感觉差距都不大&…

一招搞定“找不到xinput1_3.dll,无法继续执行代码”问题

在我们日常使用电脑进行各类工作的过程中&#xff0c;特别是在运行一些关键性软件以完成特定任务时&#xff0c;电脑屏幕上突然弹出一条醒目的错误提示信息&#xff1a;“由于找不到xinput1_3.dll,无法继续执行代码”。这个错误通常发生在使用DirectInput库时&#xff0c;而xin…

BFS解决FloodFill算法:(Leetcode:733. 图像渲染)

题目链接&#xff1a;733. 图像渲染 - 力扣&#xff08;LeetCode&#xff09; 使用广度优先遍历算法解决该问题&#xff1a; 从初始位置开始搜索&#xff0c;初始位置符合条件就入栈&#xff0c;并修改初始位置值。初始位置出栈。 再从初始位置开始广度优先搜索&#xff08;…

阿赵UE学习笔记——30、HUD简单介绍

阿赵UE学习笔记目录 大家好&#xff0c;我是阿赵。   继续学习虚幻引擎&#xff0c;这次来学习一下HUD的基础使用。 一、 什么是HUD HUD(Head-Up Display)&#xff0c;也就是俗称的抬头显示。很多其他领域里面有用到这个术语&#xff0c;比如开车的朋友可能会接触过&#xf…

【Camera Sensor Driver笔记】一、Sensor基本概念

时钟 sensor clock sensor的输入时钟 MCLK 输出时钟&#xff1a; 1. VTPixelClock&#xff1a;会影响sensor内部的帧率、曝光 VTPixelClock(vt_clk)Video Timing Clock, From sensor PLL VTPixelClock Framelengthlines x LinelengthPixelClock x FPS Framelengthlines L…

页面加载事件

2.1窗口加载事件 1.window.οnlοadfuction(){} 或者 window.addEventListerner(‘load’,function(){}) doucument.addEventListner(DOMContentLoaded,fuction(){})这个反应更快些

是德软件89600 RFID使用笔记

文章目录 1、进入RFID软件&#xff1a;2、RFID软件解调设置项3、如何查看一段指令数据 本文是日常工作的笔记分享。 lauch VSA&#xff08;矢量频谱分析&#xff09;后会出现以下界面&#xff1a; 当然这是因为频谱仪的输入有信号才显示如下&#xff1a; 否则就显示频谱仪的噪…

初识C++ · 类和对象(中)(2)

前言&#xff1a;上篇文章已经介绍了6个默认成员函数中的3个函数&#xff0c;分别是构造函数&#xff0c;析构函数&#xff0c;拷贝构造函数&#xff0c;本文介绍的是后三个&#xff0c;赋值运算符重载&#xff0c;const成员函数&#xff0c;取地址操纵符重载。 目录​​​​​…

通过使用XShell工具、Nginx环境实现服务器项目构建与发布

前言&#xff1a; 在信息化和数字化的今天&#xff0c;网站和应用的构建与发布已成为企业发展的重要一环。为了确保项目的顺利上线和稳定运行&#xff0c;选择合适的工具和环境至关重要。本文将详细介绍如何通过XShell工具以及Nginx环境来实现服务器项目的构建与发布&#xff0…

datax介绍和用法

Datax 简介 DataX 是阿里巴巴集团内被广泛使用的离线数据同步工具/平台&#xff0c;实现包括 MySQL、Oracle、SqlServer、Postgre、HDFS、Hive、ADS、HBase、TableStore(OTS)、MaxCompute(ODPS)、DRDS 等各种异构数据源之间高效的数据同步功能。 DataX本身作为数据同步框架&…

智慧图书馆为什么用rfid电子标签而不是磁条

智慧图书馆一般都会使用RFID技术&#xff0c;而不是磁条。以下是几个原因&#xff1a; 1. 效率更高&#xff1a;RFID技术可以实现非接触式读取&#xff0c;图书馆工作人员可以同时读取多本书的信息&#xff0c;大大提高了借还书的效率。 2. 数据量更大&#xff1a;RFID标签可以…

大模型-入门小知识

大模型是什么 大量参数&#xff08;上亿&#xff09;深度学习模型 人工只能包含机器学习&#xff0c;深度学习,深度学习包括大模型 单个神经元的计算模型&#xff1a; 大模型是怎么训练的 之前是算法&#xff08;神经网络&#xff09;----> 训练&#xff08;门槛降低&…