Llama 架构分析

从代码角度进行Llama 架构分析

  • Llama 架构分析
    • 前言
    • Llama 架构分析
      • 分词
      • 网络主干
        • DecoderLayer
          • Attention
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Llama 架构分析

前言

Meta 开发并公开发布了 Llama系列大型语言模型 (LLM),这是一组经过预训练和微调的生成文本模型,参数规模从 70 亿到 700 亿不等。

在大多数任务中,LLaMA-13B要比GPT-3(175B)的性能要好,LLaMA-65B和组好的模型Chinchilla-70B以及PaLM-540B的实力相当。

Llama 架构分析

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)self.layers = nn.ModuleList([MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
DecoderLayer

主干网络的layers层就是由多个DecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

DecoderLayer层中又包含了Attention层和MLP层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)attn_output = self.o_proj(attn_output)
MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227392.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp常用api讲解

Uniapp是一个基于Vue.js的跨平台开发框架,可以同时开发微信小程序、H5、App等多个平台的应用。下面是Uniapp常用的API讲解: Vue.js的API Uniapp采用了Vue.js框架,因此可以直接使用Vue.js的API。例如:v-show、v-if、v-for、compu…

二蛋赠书八期:《Java物联网、人工智能和区块链编程实战》

前言 大家好!我是二蛋,一个热爱技术、乐于分享的工程师。在过去的几年里,我一直通过各种渠道与大家分享技术知识和经验。我深知,每一位技术人员都对自己的技能提升和职业发展有着热切的期待。因此,我非常感激大家一直…

深入剖析NPM: Node包管理器的介绍和使用指南

导言:NPM(Node Package Manager)是JavaScript世界中最受欢迎的包管理器之一。它的出现大大简化了JavaScript开发过程中的依赖管理和模块化。本文将向您介绍NPM的基本概念、功能和常见用法,并为您提供一份详尽的NPM使用指南。 一、…

深度学习优化器Optimizer-SGD、mSGD、AdaGrad、RMSProp、Adam、AdamW

Optimizer 优化 学习率 l e a r n i n g r a t e : α 学习率learning\;rate: \alpha 学习率learningrate:α 防止除 0 的截断参数 : ϵ 防止除0的截断参数: \epsilon 防止除0的截断参数:ϵ t 时刻的参数 : W t t\;时刻的参数: W_{t} t时刻的参数:Wt​ t 时刻的梯度&#xf…

【改进YOLOv8】电动车电梯入户检测系统:融合HGNetv2改进改进YOLOv8

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义: 随着电动车的普及和人们对环境保护的重视,电动车的使用量逐渐增加。然而,电动车的充电问题一直是一个挑战,特别是…

JavaScript中while循环语句

循环语句(loop) - 通过循环语句可以让一段代码反复的执行多次 - 循环语句主要两种: while语句(while循环) do-while语句 for语句(for循环) while语句: - 语法: wh…

React系列:useEffect的使用

useEffect的使用 useEffect的第二个参数不同,useEffect的加载不同 当第二个参数为没有的时候 只在组件初始渲染和组件更新之后加载当第二个参数为[] 的时候 只在初始渲染之后加载当第二个参数为[有依赖] 的时候 只在初始渲染之后和依赖修改的时候进行加载 functi…

Spark报错处理系列之:Caused by: java.lang.StackOverflowError

Spark报错处理系列之:Caused by: java.lang.StackOverflowError 一、完整报错二、错误原因三、解决方法一、完整报错 INFO ApplicationMaster: Unregistering ApplicationMaster with FAILED (diag message: User class threw exception: org.apache.spark.SparkException: Jo…

贝蒂详解<string.h>哦~(用法与实现)

目录 引言: (一)字符函数和字符串函数 1.简介 2.strlen()函数 2.1用法 2.2实例 2.3 实现strlen() (1)计数法 (2)递归法 (3) 指针-指针 2.4sizeof和strlen()的区别 3.s…

PhpStorm下载、安装、配置教程

前面的文章中,都是把.php文件放在WampServer的www目录下,通过浏览器访问运行。这篇文章就简单介绍一下PhpStorm这个php集成开发工具的使用。 目录 下载PhpStorm 安装PhpStorm 配置PhpStorm 修改个性化设置 修改字符编码 配置php的安装路径 使用Ph…

网络基础3

NAT(Network Address Translation):网络地址转换 通过将内部网络的私有IP地址装换成全球唯一的公网IP地址,使内部网络可以连接到互联网。 广域网就是外网,局域网就是内网 私有IP地址:(如果是纯内…

Flask基本用法:一个HelloWorld,搭建服务、发起请求

目录 1、简介 2、安装 3、Flask使用示例 参考 1、简介 官网文档 Flask是一个轻量的web服务框架,我们可以利用它快速搭建一个服务,对外提供接口,其他人可以轻松调用我们的服务。这对算法工程师来说比较关键,我们通常不擅长搞开发…

极坐标下的牛拉法潮流计算14节点MATLAB程序

微❤关注“电气仔推送”获得资料(专享优惠) 潮流计算: 潮流计算是根据给定的电网结构、参数和发电机、负荷等元件的运行条件,确定电力系统各部分稳态运行状态参数的计算。通常给定的运行条件有系统中各电源和负荷点的功率、枢纽…

JRT实现原生Webservice发布

之前准备试试Java发布Webservice,开始以为很简单,因为C#发布很简单。后面发现太费劲了,依赖一堆包,下面几种都试了一下: JAX-WS (Java API for XML Web Services):这是Java EE平台的标准,用于创…

nodejs微信小程序+python+PHP的微博网络舆情分析系统-计算机毕业设计推荐

(4)微博信息交流:在首页导航栏上我们会看到“微博信息交流”这一菜单,我们点击进入进去以后,会看到所有管理员在后台发布的交流信息; (5)新闻资讯:用户可以查看新闻资讯信…

【面试】写一个删除列表中重复元素的函数,要求去重后元素相对位置保持不变

def dedup(items):no_dup_items []seen set()for item in items:if item not in seen:no_dup_items.append(item)seen.add(item)return no_dup_items 如果愿意也可以把上面的函数改造成一个生成器,代码如下所示: def dedup(items):seen set()for it…

【STM32入门】4.2对射红外传感器计次

1.接线方式 主要是编写传感器的驱动、配合OLED,每遮挡对射红外传感器,OLED屏幕的计数就加一。 2.驱动编写 首先新建.c文件和.h文件,命名为CountSensor 国际惯例,.c文件内要包含stm32.h头文件,然后编写 CountSensor_…

在Linux上安装配置Nginx高性能Web服务器

1 前言 Nginx是一个高性能的开源Web服务器,同时也可以作为反向代理服务器、负载均衡器、HTTP缓存以及作为一个邮件代理服务器。它以其出色的性能和灵活性而闻名,被广泛用于处理高流量的网站和应用程序。本文将介绍在Linux环境中安装Nginx的步骤&#xf…

CSS新手入门笔记整理:CSS3选择器

属性选择器 属性选择器,指的是通过“元素的属性”来选择元素的一种方式。 语法 元素[attr^"xxx"]{} 元素[attr$"xxx"]{} 元素[attr*"xxx"]{} 选择器 说明 E[attr^"xxx"] 选择元素E,其中E元素的attr属性是…

new一个对象

1.自己直接调用 function Person(name, age) {this.name name;this.age age;}let a1 new Person("小明", 20);let a2 new Person("小菜", 25);console.log(a1); 打印的对象: 2.自己模拟一个 function Person(name, age) {this.name name;this.age a…