MOE原理解释及从零实现一个MOE(专家混合模型)

什么是混合模型(MOE)

一个MOE主要由两个关键点组成:

  • 一是将传统Transformer中的FFN(前馈网络层)替换为多个稀疏的专家层(Sparse MoE layers)。每个专家本身是一个独立的神经网络,实际应用中,这些专家通常是前馈网络 (FFN),但也可以是更复杂的网络结构。
  • 二是门控网络或路由:此部分用来决定输入的token分发给哪一个专家。

可能有对FFN(前馈网络层)不太熟悉的小伙伴可以看一下下面的代码及图例,很简单就是一个我们平时常见的结构。

class FeedForward(nn.Module):def __init__(self, dim_vector, dim_hidden, dropout=0.1):super().__init__()self.feedforward = nn.Sequential(nn.Linear(dim_vector, dim_hidden),nn.ReLU(),nn.Dropout(dropout),nn.Linear(dim_hidden, dim_vector))def forward(self, x):out = self.feedforward(x)return out

示意图如下:
在这里插入图片描述

从零实现一个MOE代码

完整的从零实现MOE代码已集成至git代码训练框架项目,项目包括一个每个人都可以以此为基础构建自己的开源大模型训练框架流程、支持主流模型使用deepspeed进行Lora、Qlora等训练、主流模型的chat template模版、以及一些tricks的从零实现模块。欢迎大家star 共同学习!:https://github.com/mst272/LLM-Dojo/blob/main/llm_tricks/moe/READEME.md

1. 创建一个专家模型

这一步也很简单了,其实就是一个多层感知机MLP。

class Expert(nn.Module):def __init__(self, n_embd):super().__init__()self.net = nn.Sequential(nn.Linear(n_embd, 4 * n_embd),nn.ReLU(),nn.Linear(4 * n_embd, n_embd),nn.Dropout(dropout),)def forward(self, x):return self.net(x)

2. 创建TopKrouter

即创建MOE的路由部分。
假设我们定义了4个专家,路由取前2名专家。接收注意力层的输出作为输入,即将输入从(Batch size,Tokens,n_embed)的形状(2,4,32)投影到对应于(Batch size,Tokens,num_experts)的形状(2,4,4),其中num_experts是专家网络的计数。其中返回的indices可以理解为对于每个token的4个专家来说,选的两个专家的序号索引。
代码如下:

# 这里我们假设定义n_embed为32, num_experts=4, top_k=2class TopkRouter(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(TopkRouter, self).__init__()self.top_k = top_kself.linear =nn.Linear(n_embed, num_experts)def forward(self, mh_output):logits = self.linear(mh_output)    # (2,4,32) ---> (2,4,4)# 获取前K大的值和索引,沿列。top_k_logits, indices = logits.topk(self.top_k, dim=-1)# 创建一个形状和logits相同全'-inf'矩阵,即(2,4,4)zeros = torch.full_like(logits, float('-inf'))# 按照索引和值填充上述zeros矩阵sparse_logits = zeros.scatter(-1, indices, top_k_logits)# 对其进行softmax,未被填充的位置会为0router_output = F.softmax(sparse_logits, dim=-1)return router_output, indices

看完代码之后配合整体流程图将会更清晰:
在这里插入图片描述
更清晰的图示如下,每个字代表一个token:
在这里插入图片描述

3. 添加noisy噪声

从本质上讲,我们不希望所有token都发送给同一组“受青睐”的expert。需要一个良好平衡,因此,将标准正态噪声添加到来自门控线性层的logits。
代码对比2中的代码只改动了几行,非常的简单。

class NoisyTopkRouter(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(NoisyTopkRouter, self).__init__()self.top_k = top_kself.topkroute_linear = nn.Linear(n_embed, num_experts)# add noiseself.noise_linear =nn.Linear(n_embed, num_experts)def forward(self, mh_output):# mh_ouput is the output tensor from multihead self attention blocklogits = self.topkroute_linear(mh_output)#Noise logitsnoise_logits = self.noise_linear(mh_output)#Adding scaled unit gaussian noise to the logitsnoise = torch.randn_like(logits)*F.softplus(noise_logits)noisy_logits = logits + noisetop_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)zeros = torch.full_like(noisy_logits, float('-inf'))sparse_logits = zeros.scatter(-1, indices, top_k_logits)router_output = F.softmax(sparse_logits, dim=-1)return router_output, indices

4. 构建完整的稀疏MOE module

前面的操作主要是获取了router分发的结果,获取到这些结果后我们就可以将router乘给对应的token。这种选择性加权乘法最终构成了稀疏MOE。
代码部分如下所示:

class SparseMoE(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(SparseMoE, self).__init__()self.router = NoisyTopkRouter(n_embed, num_experts, top_k)self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])self.top_k = top_kdef forward(self, x):# 1. 输入进入router得到两个输出gating_output, indices = self.router(x)# 2.初始化全零矩阵,后续叠加为最终结果final_output = torch.zeros_like(x)# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平flat_x = x.view(-1, x.size(-1))flat_gating_output = gating_output.view(-1, gating_output.size(-1))# 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权for i, expert in enumerate(self.experts):# 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2expert_mask = (indices == i).any(dim=-1)# 5. 展平操作flat_mask = expert_mask.view(-1)# 如果当前专家是任意一个token的前top2if flat_mask.any():# 6. 得到该专家对哪几个token起作用后,选取token的维度表示expert_input = flat_x[flat_mask]# 7. 将token输入expert得到输出expert_output = expert(expert_input)# 8. 计算当前专家对于有作用的token的权重分数gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)# 9. 将expert输出乘上权重分数weighted_output = expert_output * gating_scores# 10. 循环进行做种的结果叠加final_output[expert_mask] += weighted_output.squeeze(1)return final_output

其中的一些讲解都在注释中了,特别注意的是该部分的逻辑是以专家为单位遍历每个专家,抽取每个专家所对应的tokens。结合上述代码注释中的序号,可以参考下面tensor流向图,可以完整清晰的理解该内容。
在这里插入图片描述
在这里插入图片描述

5. 将MOE与transformer结合

这一部分主要就是将上述所做的工作与常规的transformer层结合,即用moe替代MLP层。

class Block(nn.Module):"""Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """def __init__(self, n_embed, n_head, num_experts, top_k):super().__init__()head_size = n_embed // n_headself.sa = MultiHeadAttention(n_head, head_size)self.smoe = SparseMoE(n_embed, num_experts, top_k)self.ln1 = nn.LayerNorm(n_embed)self.ln2 = nn.LayerNorm(n_embed)def forward(self, x):x = x + self.sa(self.ln1(x))x = x + self.smoe(self.ln2(x))return x

总结

最终我们得到了上述block,算是一个完整的模块了,并从头到尾将MOE的实现细节都讲解了一遍,理解原理后我们就可以对当前的一些主流模型进行moe魔改等操作了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/22798.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[论文笔记]Mistral 7B

引言 今天带来大名鼎鼎的Mistral 7B的论文笔记。 作者推出了Mistral 7B,这是一个70亿参数的语言模型。Mistral 7B在所有评估基准中表现优于最佳的13B开源模型(Llama 2),并且在推理、数学和代码生成方面胜过最佳发布的34B模型(Ll…

odoo qweb template小结

QWeb QWeb是一个基于XML的模板引擎,可用于生成HTML片段和页面。它使用XML格式来定义模板。QWeb通过在模板中添加特定的标记,来指示模板中的数据和逻辑部分。使用QWeb,你可以创建各种不同的模板,例如列表视图,表单视图和报告等。QWeb支持标准的HTML标记和控制结构,如if语…

D435相机结合Yolo V8识别出目标物体,并转点云出抓取位姿。

最近项目上需要完成整个识别、定位、到最后的抓取流程。 分享一下,通过使用D435相机并结合Yolo V8识别出目标物体后,抠取出目标物体部分的有效深度图,最后将前景物体部分的RGB D435相机结合Yolo V8识别出目标物体,并转点云出抓取位…

从高海拔到严寒季的测量作业更要「快准稳」,怎么实现?

西藏那曲海拔4500米公路勘测项目赶工期 “必须要保障在西藏那曲地区承接的公路勘测项目赶工期需求,海拔高达4500米、网络通讯不足、部分范围存在无网以及地基信号覆盖可能不足的情况,需要能满足环境和项目需求的专业RTK设备紧急送到。” 客户的一个电话…

做外贸是否需要代运营?

相信很多做外贸的小伙伴或者公司都有这样的一个困扰,尤其是做SEO以及平台的公司,会很纠结要不要将公司的运营承包出去。 而之所以有这样的困扰,一部分是公司的业务员可能并不擅长运营,或者是业务员抽不出时间去管理运营这块。 而…

映射网络驱动器自动断开的解决方法

如果将驱动器映射到网络共享,映射的驱动器可能会在定期处于非活动状态后断开连接,并且 Windows 资源管理器可能会在映射驱动器的图标上显示红色 X。,出现此行为的原因是,系统可以在指定的超时期限后断开空闲连接, (默认…

PWA缓存策略区别NetworkOnly/CacheFirst/CacheOnly/NetworkFirst/StaleWhileRevalidate

现在来看看 Workbox 提供的缓存策略,主要有这几种: cache-first, cache-only, network-first, network-only, stale-while-revalidate 在前面看到,实例化的时候会给 workbox 挂载一个 Strategies 的实例。提供上面一系列的缓存策略&…

半导体制造中的压缩气体及其高压扩散器如何选择 北京中邦兴业

了解高压扩散器 高压扩散器(HPD)对于保持压缩气体样品中颗粒计数的精度至关重要。它们充当颗粒计数器和压缩气体管线之间的纽带,在气体进入颗粒计数器的样品入口时使其扩散。这确保了压力得到控制,以防止对颗粒计数器样品室的敏感…

uniapp学习(001 前期介绍)

零基础入门uniapp Vue3组合式API版本到咸虾米壁纸项目实战,开发打包微信小程序、抖音小程序、H5、安卓APP客户端等 总时长 23:40:00 共116P 此文章包含第1p-第p10的内容 简介 目录结构 效果 打包成小程序 配置开发者工具 打开安全按钮 使用uniapp的内置组件…

3DMAX一键虚线图形插件DashedShape使用方法

3DMAX一键虚线图形插件使用方法 3dMax一键虚线图形插件,允许从场景中拾取的样条线创建虚线形状。该工具使你能够创建完全自定义的填充图案,为线段设置不同的材质ID,并在视口中进行方便的预览。 【版本要求】 3dMax 2012 – 2025(…

数据结构与算法笔记:基础篇 - 数组:为什么数组都是从0开始编号

概述 提到数组,大家应该都不陌生。每一种编程语言基本都会有数组这种数据类型。不过,它不仅仅是一种编程语言中的数据类型,还是一种基础的数据结构。尽管数组看起来非常简单,但是我估计很多人并没有理解这个数据结构的精髓。 在…

AB测试实战

AB测试实战 1、AB测试介绍🐾 很多网站/APP的首页都会挂一张头图(Banner),用来展示重要信息,头图是否吸引人会对公司的营收带来重大影响,一家寿险公司Humana设计了如下三张头图,现在需要决定使用哪一张放到首页&#x…

FastDFS分布式文件系统

一、概述 FastDFS是一款由国人余庆开发的轻量级开源分布式文件系统,它对文件进行管理,功能包括:文件存储、文件同步、文件访问(文件上传、文件下载)等,主要解决大容量文件存储和高并发访问问题&#xff0c…

jenkins应用2-freestyle-job

1.jenkins应用 1.jenkins构建的流程 1.使用git参数化构建,用标签区分版本 2.git 拉取gitlab远程仓库代码 3.maven打包项目 4.sonarqube经行代码质量检测 5.自定义制作镜像发送到远程仓库harbor 6.在远程服务器上拉取代码启动容器 这个是构建的整个过程和步骤…

保姆级教程:Redis 主从复制原理及集群搭建

😄作者简介: 小曾同学.com,一个致力于测试开发的博主⛽️,主要职责:测试开发、CI/CD 如果文章知识点有错误的地方,还请大家指正,让我们一起学习,一起进步。 😊 座右铭:不…

线程池的工作原理

文章目录 一、应用场景二、工作原理三、主要函数 一、应用场景 传统并发变成的缺陷: 1.创建和销毁线程上花费的时间和消耗的系统资源,甚至可能要比花在处理实际的用户请求的时间和资源要多得多 2. 活动的线程需要消耗系统资源,如果启动太多&…

26、matlab多项式曲线拟合:polyfit ()函数

1、polyfit 多项式曲线拟合 语法 语法:p polyfit(x,y,n) 返回次数为 n 的多项式 p(x) 的系数,该阶数是 y 中数据的最佳拟合(基于最小二乘指标)。 语法:[p,S] polyfit(x,y,n) 还返回一个结构体 S 语法:[…

优化 mac 储存空间的方法 只需一招为你的苹果电脑提速

在职场中,许多人都对苹果电脑情有独钟。苹果电脑以其简洁美观的设计、流畅稳定的性能以及出色的用户体验,成为了众多职场人士的得力助手。无论是处理文档、制作演示文稿,还是进行创意设计等工作,苹果电脑都能展现出其独特的优势&a…

微信小程序公众号二合一分销商城源码系统 基于PHP+MySQL组合开发的 可多商户商家入驻 带完整的安装代码包以及搭建教程

系统概述 微信小程序公众号二合一分销商城源码系统,是基于PHPMySQL组合开发的一款高效、稳定的电子商务平台解决方案。该系统创新性地将微信公众号与小程序的功能进行了深度整合,为商家提供了一个功能齐全、易于管理的分销商城系统。通过此系统&#xf…

Vue3+vant 带你实现常见的历史记录的业务功能

前言 大部分小伙伴不管是开发PC端还是H5移动端,都会遇到历史搜索的功能。对用户的历史记录进行增删查可以是接口,也可以是前端用缓存实现,一般用浏览器缓存实现的比较多,这篇文章就来教你如何用LocalStorage对历史记录数据的存储、…