LLM面面观之MoE

1. 背景

根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。

此文是本qiang~针对大语言模型的MoE的整理,包括原理、流程及部分源码

2. MoE原理

MoE的流行源于”欧洲的OpenAI” Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基础模型使用的是Mistral AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。

本文以《Mixtral of Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:

图2.1 MoE的原理

(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate router)进行控制

(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出

(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。

(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。

3. 源码

本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…

为了普适性,下面的代码截取了transformers框架中的代码。

首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral, 主要流程是:

(1) 先经过gate_proj和up_proj的2个[hidden_size, intermediate_size]的线性转换

(2) 使用激活函数对gate_proj进行激活

(3) 二者的内积再经过down_proj线性转换。

class MistralMLP(nn.Module):def __init__(self, config):super().__init__()self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先经过网关路由self.gate

(2) 然后选择其中2个专家,并归一化

(3) 之后遍历每个专家网络,并按照expert_mask进行筛选

(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权

(5) 最后原地增加先前初始化的最终结果变量final_hidden_states

class MixtralSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.hidden_dim = config.hidden_sizeself.ffn_dim = config.intermediate_sizeself.num_experts = config.num_local_expertsself.top_k = config.num_experts_per_tok# gatingself.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:""" """batch_size, sequence_length, hidden_dim = hidden_states.shapehidden_states = hidden_states.view(-1, hidden_dim)# router_logits: (batch * sequence_length, n_experts)router_logits = self.gate(hidden_states)routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)routing_weights /= routing_weights.sum(dim=-1, keepdim=True)# we cast back to the input dtyperouting_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)# One hot encode the selected experts to create an expert mask# this will be used to easily index which expert is going to be sollicitatedexpert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# Loop over all available experts in the model and perform the computation on each expertfor expert_idx in range(self.num_experts):expert_layer = self.experts[expert_idx]idx, top_x = torch.where(expert_mask[expert_idx])if top_x.shape[0] == 0:continue# in torch it is faster to index using lists than torch tensorstop_x_list = top_x.tolist()idx_list = idx.tolist()# Index the correct hidden states and compute the expert hidden state for# the current expert. We need to make sure to multiply the output hidden# states by `routing_weights` on the corresponding tokens (top-1 and top-2)current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]# However `index_add_` only support torch tensors for indexing so we'll use# the `top_x` tensor here.final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)return final_hidden_states, router_logits

其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。

class MixtralBlockSparseTop2MLP(nn.Module):def __init__(self, config: MixtralConfig):super().__init__()self.ffn_dim = config.intermediate_sizeself.hidden_dim = config.hidden_sizeself.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, hidden_states):current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)current_hidden_states = self.w2(current_hidden_states)return current_hidden_states

4. MoE微调

由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。

可以参考开源项目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 问:MoE 8*7B的模型是56B参数?

答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。

(2) 问:MoE的基础模型是Mistral 7B?

答:不是,MoE的模型架构与Mistral 7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。

(3) MoE的稀疏性(sparse)体现在哪里?

答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。

6. 总结

一句话足矣~

本文主要针对大语言模型的MoE,包括原理及部分源码。

此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。

7. 参考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE: https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE开源指令微调框架Firefly: https://github.com/yangjianxin1/Firefly

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/786946.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mybatis——查询数据

查询操作 根据用户id查询单条记录,在映射器接口(UserMapper)中定义如下方法: package org.example.mapper;import org.example.demo.User;import java.util.List;public interface UserMapper {//根据id查询UserUser selectUserById(Integer userId); …

stable diffusion 的 GPU 不足怎么解决

稳定扩散(stable diffusion)是一种用于图像处理和计算机视觉任务的图像滤波算法。 当使用Stable Diffusion过程中遇到GPU显示内存不足的问题时。解决这个问题的方法有以下几种: 目前,对我来说,就最后一点能够暂时解决当…

GaussDB云数据库极简版安装与使用-新手指南

一、前言 作为一款领先的企业级数据库管理系统,GaussDB 提供了强大的性能、高度可靠性和丰富的功能,是企业构建可靠、高性能的数据库解决方案的理想选择。 本文主要针对高校和个人测试环境,介绍极简版安装和使用过程,更加适合高…

SwiftUI Swift 选择图片 添加图片

1. 添加记帐时添加图片功能 2. Show me the code // // TestPhotoPicker.swift // pandabill // // Created by 朱洪苇 on 2024/3/30. //import SwiftUI import PhotosUI import Foundationstruct TestPhotoPicker: View {State private var selectedItem: PhotosPickerIt…

Php_Code_challenge12

题目: 答案: 解析: 字符串拼接。

文献阅读:通过 NeuronChat 从单细胞转录组推断神经元-神经元通信

文献介绍 「文献题目」 Inferring neuron-neuron communications from single-cell transcriptomics through NeuronChat 「研究团队」 聂青(加利福尼亚大学欧文分校) 「发表时间」 2023-02-28 「发表期刊」 Nature Communications 「影响因子」 16.6…

15 - grace序列处理 - 十三点滑动平均法

grace序列处理 -十三点滑动平均法 滑动平均是一种常用的平滑数据的方法,可以用于去除噪声或者提取趋势。十三点滑动平均是指使用窗口大小为13的滑动平均,应用于GRACE序列处理中可以去除周年项的影响。 十三点滑动平均的计算公式为: y [ n ] = ( x [ n − 6 ]

互联网轻量级框架整合之JavaEE基础I

不得不解释得几个概念 JavaEE SUN公司提出来的企业版Java开发中间件,主要用于企业级互联网系统的框架搭建,同时因为Java语言优质的平台无关性、可移植性、健壮性、支持多线程和安全性等优势,其迅速成为构建企业互联网平台的主流技术&#x…

基于UML的系统分析与设计

统一建模语言(Unified Modeling Language,UML)是一种为面向对象系统的产品进行说明、可视化和编制文档的一种标准语言,是非专利的第三代建模和规约语言。UML是面向对象设计的建模工具,独立于任何具体程序设计语言。 毕业设计是实现本科教学培…

Php_Code_challenge16

题目: 答案: 解析: 所以科学计数法绕过即可。

macOS Sonoma 14.4 23E214 VMware系统包下载地址,简单便捷,导入即可用!

这回分享的是VMware虚拟机macOS 14.4版本的系统包,这种系统包是已经在VMware虚拟机中安装好了的macOS系统。省去了繁琐的安装步骤与稍微漫长的等待时间。此次更新的包为诗林工作室制作的最新一个VMware系统包版本。分享给那些想快速体验macOS 14版本的朋友。 使用方…

C++ AVL树(旋转)

我们之前学习了搜索二叉树,我们知道普通的搜索二叉树会有特殊情况出现使得二叉树的两枝极其不平衡形成我们通俗说的歪脖子树: 这样的树一定会使得我们的增删查的效率变低;为了避免这种极端的情况出现,在1962年有两位伟大的俄罗斯数…

EasyExcel 复杂表头的导出(动态表头和静态表头)

问题:如图,1部分的表头是动态的根据日期变化,2部分是数据库对应的字段,静态不变的; 解决方案:如果不看1的部分,2部分内容可以根据实体类注解的方式导出,那么我们是不是可以先将动态表…

Centos7 安装 Oracle19c

下载oracle预安装包 wget http://yum.oracle.com/repo/OracleLinux/OL7/latest/x86_64/getPackage/oracle-database-preinstall-19c-1.0-1.el7.x86_64.rpm 下载19c安装包 https://www.oracle.com/cn/database/technologies/oracle-database-software-downloads.html#19c 选择…

计算机网络-HTTP相关知识-HTTPS基础

HTTP与HTTPS的区别: HTTPS在TCP和HTTP网络层之间加入了SSL/TLS安全协议层。这个安全协议层可以对数据进行加密,确保数据在传输过程中的安全。HTTPS在TCP三次握手之后,还需进行SSL/TLS的握手过程。这个握手过程主要是为了在客户端和服务器之间…

超声波清洗机是干什么用的?2024年有用的超声波清洗机推荐

随着科技的不断进步,超声波清洗机已经成为了家庭和专业场所不可或缺的高效清洁工具。它利用超声波波动产生的微小气泡来清洁物品表面及细缝中的污渍,实现深层次的清洁效果。特别是对于眼镜这样的精密物品,定期进行深度清洁不仅能够确保视觉的…

【算法刷题day10】Leetcode:232.用栈实现队列、225. 用队列实现栈

文章目录 Leetcode 232.用栈实现队列解题思路代码总结 Leetcode 225. 用队列实现栈解题思路代码总结 stack、queue和deque对比 草稿图网站 java的Deque Leetcode 232.用栈实现队列 题目:232.用栈实现队列 解析:代码随想录解析 解题思路 一个栈负责进&a…

前端二维码生成工具小程序:构建营销神器的技术解析

摘要: 随着数字化营销的不断深入,二维码作为一种快速、便捷的信息传递方式,已经广泛应用于各个领域。本文旨在探讨如何通过前端技术构建一个功能丰富、操作简便的二维码生成工具小程序,为企业和个人提供高效的营销支持。 一、引言…

什么是过载

宇航员相关知识会涉及到过载,导弹相关知识也会涉及到过载,如导弹的过载加速度,什么是过载呢?博主从B站上看到一UP主讲的很好, 该up主视频链接: 过载是什么_哔哩哔哩_bilibili 内容截图如下:

钉钉服务端API报错 43008 参数需要multipart类型

钉钉服务端API报错 43008 参数需要multipart类型 problem 使用媒体文件上传接口,按照文档输入参数,结果返回报错 # 参数 {"access_token": "xxx""type": "image","media": "/Users/xxx/xxx/s…