Llama 架构分析

从代码角度进行Llama 架构分析

  • Llama 架构分析
    • 前言
    • Llama 架构分析
      • 分词
      • 网络主干
        • DecoderLayer
          • Attention
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Llama 架构分析

前言

Meta 开发并公开发布了 Llama系列大型语言模型 (LLM),这是一组经过预训练和微调的生成文本模型,参数规模从 70 亿到 700 亿不等。

在大多数任务中,LLaMA-13B要比GPT-3(175B)的性能要好,LLaMA-65B和组好的模型Chinchilla-70B以及PaLM-540B的实力相当。

Llama 架构分析

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)self.layers = nn.ModuleList([MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
DecoderLayer

主干网络的layers层就是由多个DecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

DecoderLayer层中又包含了Attention层和MLP层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)attn_output = self.o_proj(attn_output)
MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227392.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二蛋赠书八期:《Java物联网、人工智能和区块链编程实战》

前言 大家好!我是二蛋,一个热爱技术、乐于分享的工程师。在过去的几年里,我一直通过各种渠道与大家分享技术知识和经验。我深知,每一位技术人员都对自己的技能提升和职业发展有着热切的期待。因此,我非常感激大家一直…

【改进YOLOv8】电动车电梯入户检测系统:融合HGNetv2改进改进YOLOv8

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义: 随着电动车的普及和人们对环境保护的重视,电动车的使用量逐渐增加。然而,电动车的充电问题一直是一个挑战,特别是…

贝蒂详解<string.h>哦~(用法与实现)

目录 引言: (一)字符函数和字符串函数 1.简介 2.strlen()函数 2.1用法 2.2实例 2.3 实现strlen() (1)计数法 (2)递归法 (3) 指针-指针 2.4sizeof和strlen()的区别 3.s…

PhpStorm下载、安装、配置教程

前面的文章中,都是把.php文件放在WampServer的www目录下,通过浏览器访问运行。这篇文章就简单介绍一下PhpStorm这个php集成开发工具的使用。 目录 下载PhpStorm 安装PhpStorm 配置PhpStorm 修改个性化设置 修改字符编码 配置php的安装路径 使用Ph…

网络基础3

NAT(Network Address Translation):网络地址转换 通过将内部网络的私有IP地址装换成全球唯一的公网IP地址,使内部网络可以连接到互联网。 广域网就是外网,局域网就是内网 私有IP地址:(如果是纯内…

Flask基本用法:一个HelloWorld,搭建服务、发起请求

目录 1、简介 2、安装 3、Flask使用示例 参考 1、简介 官网文档 Flask是一个轻量的web服务框架,我们可以利用它快速搭建一个服务,对外提供接口,其他人可以轻松调用我们的服务。这对算法工程师来说比较关键,我们通常不擅长搞开发…

极坐标下的牛拉法潮流计算14节点MATLAB程序

微❤关注“电气仔推送”获得资料(专享优惠) 潮流计算: 潮流计算是根据给定的电网结构、参数和发电机、负荷等元件的运行条件,确定电力系统各部分稳态运行状态参数的计算。通常给定的运行条件有系统中各电源和负荷点的功率、枢纽…

JRT实现原生Webservice发布

之前准备试试Java发布Webservice,开始以为很简单,因为C#发布很简单。后面发现太费劲了,依赖一堆包,下面几种都试了一下: JAX-WS (Java API for XML Web Services):这是Java EE平台的标准,用于创…

nodejs微信小程序+python+PHP的微博网络舆情分析系统-计算机毕业设计推荐

(4)微博信息交流:在首页导航栏上我们会看到“微博信息交流”这一菜单,我们点击进入进去以后,会看到所有管理员在后台发布的交流信息; (5)新闻资讯:用户可以查看新闻资讯信…

【STM32入门】4.2对射红外传感器计次

1.接线方式 主要是编写传感器的驱动、配合OLED,每遮挡对射红外传感器,OLED屏幕的计数就加一。 2.驱动编写 首先新建.c文件和.h文件,命名为CountSensor 国际惯例,.c文件内要包含stm32.h头文件,然后编写 CountSensor_…

在Linux上安装配置Nginx高性能Web服务器

1 前言 Nginx是一个高性能的开源Web服务器,同时也可以作为反向代理服务器、负载均衡器、HTTP缓存以及作为一个邮件代理服务器。它以其出色的性能和灵活性而闻名,被广泛用于处理高流量的网站和应用程序。本文将介绍在Linux环境中安装Nginx的步骤&#xf…

new一个对象

1.自己直接调用 function Person(name, age) {this.name name;this.age age;}let a1 new Person("小明", 20);let a2 new Person("小菜", 25);console.log(a1); 打印的对象: 2.自己模拟一个 function Person(name, age) {this.name name;this.age a…

[Linux] LVS负载均衡群集——DR模式

一、 DR模式的特点 直接路由: 在LVS_DR模式下,负载均衡器不修改数据包的IP地址,只修改目的MAC地址。这使得数据包可以直接路由到后端实际服务器上,而不需要返回到负载均衡器。 高性能: 由于数据包在传输过程中不需要回…

本地运行大语言模型并可视化(Ollama+big-AGI方案)

目前有两种方案支持本地部署,两种方案都是基于llamacpp。其中 Ollama 目前只支持 Mac,LM Studio目前支持 Mac 和 Windows。 LM Studio:https://lmstudio.ai/ Ollama:https://ollama.ai/download 本文以 Ollama 为例 step1 首先下…

STM32_启动流程详解

目录标题 前言 启动流程概述复位中断函数详解SystemInit函数详解 __main函数详解 附录 stm32单片机的存储器映像中断向量表的映射 前言 最近在学习IAP远程OTA升级单片机固件程序,发现自己对单片机的启动流程还不是那么了解,就总结整理一下吧。 启动流程…

QT实现四则运算计算器

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);this->setMaximumSize(240,300);this->setMinimumSize(240,300);this->setWindowTitle("计算器&…

node.js mongoose简述

目录 官方文档 mongoose Schema Model Query document 关系 官方文档 Mongoose v8.0.3: Getting Started mongoose Mongoose 是一个 Node.js 环境下 MongoDB 的对象建模工具。它提供了一种在应用程序中与 MongoDB 数据库进行交互的方式,使得开发者能够使用…

NoSQL 数据库有哪些典型应用?

前面的内容介绍了数据库读写分离和分库分表相关知识,都是针对关系型数据库的,即通常说的 RDBMS。除了关系型数据库,NoSQL 在项目开发中也有着越来越重要的作用,与此同时,NoSQL 相关的内容也是面试的常客。今天我们一起…

函数难题:排列

给定一个整数 n,将数字 1∼n 排成一排,将会有很多种排列方法。 现在,请你按照字典序将所有的排列方法输出。 输入格式 共一行,包含一个整数 n。 输出格式 按字典序输出所有排列方案,每个方案占一行。 数据范围 …

【Linux】驱动

驱动 驱动程序过程 系统调用 用户空间 内核空间 添加驱动和调用驱动 驱动程序是如何调用设备硬件 驱动 在计算机领域,驱动(Driver)是一种软件,它充当硬件设备与操作系统之间的桥梁,允许它们进行通信和协同工作。驱动程…