Mistral MOE架构全面解析

从代码角度理解Mistral架构

  • Mistral架构全面解析
    • 前言
    • Mistral 架构分析
      • 分词
      • 网络主干
        • MixtralDecoderLayer
          • Attention
          • MOE
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Mistral架构全面解析

前言

Mixtral-8x7B 大型语言模型 (LLM) 是一种预训练的生成式稀疏专家混合模型。在大多数基准测试中,Mistral-8x7B 的性能优于 Llama 2 70B。

Mixtral 8x7B 是 Mistral AI 全新发布的 MoE 模型,MoE 是 Mixture-of-Experts 的简称,具体的实现就是将 Transformer 中的 FFN 层换成 MoE FFN 层,其他部分保持不变。在训练过程中,Mixtral 8x7B 采用了 8 个专家协同工作,而在推理阶段,则仅需激活其中的 2 个专家。这种设计巧妙地平衡了模型的复杂度和推理成本,即使在拥有庞大模型参数的情况下,也能保证高效的推理性能,使得 MoE 模型在保持强大功能的同时,也具备了更优的实用性和经济性。

  • 在大多数基准测试中表现优于Llama 2 70B
  • 甚至足以击败GPT-3.5上下文窗口为32k
  • 可以处理英语、法语、意大利语、德语和西班牙语
  • 在代码生成方面表现优异

huggingface上给出基本的加载方法

from transformers import AutoModelForCausalLM, AutoTokenizermodel_id = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)model = AutoModelForCausalLM.from_pretrained(model_id)text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt")outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

在这里插入图片描述

Mistral 架构分析

其它结构和llama的一模一样,知道llama结构的话,省流直接看MOE部分。

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)self.layers = nn.ModuleList([MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
MixtralDecoderLayer

主干网络的layers层就是由多个MixtralDecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

MixtralDecoderLayer层中又包含了Attention层和MOE层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

在这里插入图片描述

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)attn_output = self.o_proj(attn_output)
MOE

MOE层,也就是我们的专家模块,简单来说,主要干的就是通过一个线性层,得到8个专家,从这8个专家中选出最专业的2个,把他们的权重相加,输入到MLP层,得到最终的结果。

  • attention层得到的hidden_states经过控制门(nn.Linear)得到8个输出。(有点像多分类)
  • t通过softmax计算8个输出的概率值
  • 从8个中选择概率值最高的两个专家
  • 概率最高的两个专家进行权重相加,并计算相对概率值
  • 这两个专家输入到MLP层中进行一系列计算得到最后结果

在这里插入图片描述

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
#这里通过一个线性层,得到8个输出(n_experts),也就是所谓的专家。
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
#这里通过softmax计算8个输出的概率值,如(0.2,0.3,0.0833,0.0833,0.0833,0.0833,0.0833,0.0833)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
#从8个中选择概率值最高的两个专家((0.2,0.3)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
#概率最高的两个专家进行权重相加,并计算相对概率值((0.2,0.3)->(0.4,0.6)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
#初始化最终结果
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
#掩码
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)for expert_idx in range(self.num_experts):expert_layer = self.experts[expert_idx]#通过掩码找到top2的位置idx, top_x = torch.where(expert_mask[expert_idx])if top_x.shape[0] == 0:continuetop_x_list = top_x.tolist()idx_list = idx.tolist()#top2对应的向量current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)#经过mlpcurrent_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]#加到final_hidden_states中final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/228616.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探索顺序表:数据结构中的秩序之美(c语言实现常见功能接口)

在我们的数据结构探索中,我们已经探讨时间复杂度、空间复杂度。大家可以移步到我的上篇文章: 打开数据结构大门:深入理解时间与空间复杂度 今天,我们将深入研究另一个重要的主题——顺序表 全部的源代码大家可以去我github主页…

web服务器之——www服务器的基本配置

目录 一、www简介 1、什么是www 2、www所用的协议 3、WEB服务器 4、主要数据 5、浏览器 二、 网址及HTTP简介 1、HTTP协议请求的工作流程 三、www服务器的类型(静态网站(HTML), 动态网站(jsp python,php,perl)) 1、 仅提供…

Windows设备管理

1、前言 熟悉Windows系统的都应该使用过设备管理器。设备管理器将操作系统中所有已安装的设备分类展现出来。同时提供了安装、卸载、启用和禁用的功能。 那么,我们应该如何通过C编程的方式实现这种功能呢?答案很简单,那就是使用SetupDi函数族…

Lumerical 选项------superimpose structure

Lumerical 选项------superimpose structure 简介正文 简介 这里给大家介绍一下 Modal analysis 计算中的 superimpose structure 选项的作用。 正文 当我们勾选上 superimpose structure 选项时, 当我们取消勾选时 通过对比我们得到,勾选 superimp…

Windows11环境下配置深度学习环境(Pytorch)

目录 1. 下载安装Miniconda2. 新建Python3.9虚拟环境3. 下载英伟达驱动4. 安装CUDA版Pytorch5. CPU版本pytorch安装 1. 下载安装Miniconda 下载安装包:镜像文件地址 将Miniconda相关路径添加至系统变量的路径中。 打开Anaconda Powershell Prompt,输入…

计算机组成原理-指令系统CISC和RISC

文章目录 总览CISC和RISC 总览 CISC和RISC 存储程序就是用一个电路再加上存储部件构成 可访存指令不同 RISC更自由,因为很多函数没有固定,是自己写的 由于CISC各个指令执行时间不一样,要实现指令流水线比较困难 由于CISC可访存指令没有限制…

游戏、算法竞赛与退役(流水账版)

写在前面 不出意外的话,这东西本该咕到翻年之后再发的,但好像催稿催的有点厉害,于是就找个机会把他写了(笑) 最初是只想写个算法竞赛退役记的,后面发觉写起来就有点收不住,算法竞赛牵扯到太多…

CSS margin-trim

margin-trim 主角登场主角的局限性兼容性 margin-trim &#x1f9ea;这是一个实验性的属性, 目前仅有 Safari 支持 看这个属性的名字就知道, 外边距修剪. 平常都会遇到一些排版上的问题, 比如垂直排列的元素之间增加下外边距 <div><li>123</li><li>…

JAVA序列化(创建可复用的 Java 对象)

JAVA 序列化(创建可复用的 Java 对象) 保存(持久化)对象及其状态到内存或者磁盘 Java 平台允许我们在内存中创建可复用的 Java 对象&#xff0c;但一般情况下&#xff0c;只有当 JVM 处于运行时&#xff0c;这些对象才可能存在&#xff0c;即&#xff0c;这些对象的生命周期不…

算法竞赛备赛进阶之树形DP训练

目录 1.树的最长路径 2.树的中心 3.数字转换 4.二叉苹果树 5.战略游戏 6.皇宫守卫 树形DP是一种动态规划方法&#xff0c;主要用于解决树形结构的问题。在树形DP中&#xff0c;通常会使用动态规划的思想来求解最优化问题。其核心在于通过不断地分解问题和优化子问题来解决…

2023年国家基地“楚慧杯”网络空间安全实践能力竞赛 Web方向 题解wp

前言&#xff1a;三小时的比赛&#xff0c;和强网同时结束还要当场交wp&#xff0c;汗流浃背&#xff0c;烧起来了啊啊啊啊~ eaaeval 目录扫出备份文件 源码如下 <?php class Flag{public $a;public $b;public function __construct(){$this->a admin;$this->b …

c++字符串和日期基础

一&#xff0c;字母三角形 #include<string> #include<iostream> using namespace std; int main() {int n 0;cin >> n;for (int i 1; i < n; i)//i代表行数{string spacestring(n - i, );//前半部分空格string ch string(2 * i - 1, A i - 1);cout…

工具在手,创作无忧:一键下载安装Auto CAD工具,让艺术创作更加轻松愉悦!

不要再浪费时间在网上寻找Auto CAD的安装包了&#xff01;因为你所需的一切都可以在这里找到&#xff01;作为全球领先的设计和绘图软件&#xff0c;Auto CAD为艺术家、设计师和工程师们提供了无限的创作潜力。不论是建筑设计、工业设计还是室内装饰&#xff0c;Auto CAD都能助…

《Linux C编程实战》笔记:文件属性操作函数

获取文件属性 stat函数 在shell下直接使用ls就可以获得文件属性&#xff0c;但是在程序里应该怎么获得呢&#xff1f; #include<sys/types.h> #include <sys/stat.h> #include <unistd.h> int stat(const char *file_name,struct stat *buf); int fstat(i…

【eNSP实验项目】eNSP实验配置项目教程,ensp安装步骤

eNSP安装教程 附安装包 eNSP介绍安装教程1.安装 VirtualBox2.安装 WinPcap3.安装 Wireshark4.eNSP安装 eNSP介绍 eNSP是华为提供的一款功能强大的网络仿真平台&#xff0c;适用于学习、实践和测试企业网络场景&#xff0c;可以帮助用户深入理解网络知识和技术。 eNSP安装,需要…

Tektronix泰克TCP303示波器电流探头

主要特点和优点&#xff1a; ● 交流/直流测量功能 ● DC~100MHz电流探头放大器&#xff08;TCPA300&#xff09;&#xff0c;当使用&#xff1a; - DC~100MHz, 30A DC&#xff08;TCP312&#xff09; - DC~50MHz, 50A DC&#xff08;TCP305&#xff09; - DC~5MHz, 150A DC&a…

关于多重背包的笔记

多重背包可以看作01背包的拓展&#xff0c; 01背包是选或者不选。多重背包是选0个一直到选s个。 for (int i 1; i < n; i) {for (int j m; j > w[i]; --j){f[j] max(f[j], f[j - 1*w[i]] 1*v[i], f[j - 2*w[i]] 2*v[i],...f[j - s*w[i]] s*v[i]);} } 由上述伪代码…

Mybatis-plus是使用,告别繁琐的CRUD编写,自动生成直接使用

目录 一、简介 1. 是什么 2. 特性 3. 框架结构 4. 常用注解 二、搭建使用 1. 依赖 2. 生成器 3. 生成 4. 引用 5. 路径访问 三、测试 四、雪花ID 每篇一获 Mybatis-plus&#xff08;简称 MP&#xff09;是一个 MyBatis (opens new window)的增强工具&#xff0c;…

VRRP协议

一.基本概念 1.概念 VRRP能够在不改变组网的情况下&#xff0c;将多台路由器虚拟成一个虚拟路由器&#xff0c;通过配置虚拟路由器的IP地址为默认网关&#xff0c;实现网关的备份。协议版本&#xff1a;VRRPv2&#xff08;常用&#xff09;和VRRPv3&#xff1a;VRRPv2仅适用于…

【基于卷积神经网络的疲劳检测与预警系统的设计与实现】

基于卷积神经网络的疲劳检测与预警系统的设计与实现 引言数据集介绍技术与工具1. OpenCV2. TensorFlow3. 卷积神经网络&#xff08;CNN&#xff09; 系统功能模块1. 视频采集模块2. 图像预处理模块3. 人脸识别模块4. 疲劳程度判别模块5. 报警模块 系统设计创新点1. 实时监测与预…