MoEs and Transformers 笔记

ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83
MoEs and Transformers

Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型的参数量扩展到超过 6000 亿并不令人惊讶。

GShard 将在编码器和解码器中的每个前馈网络 (FFN) 层中的替换为使用 Top-2 门控的混合专家模型 (MoE) 层。下图展示了编码器部分的结构。这种架构对于大规模计算非常有效: 当扩展到多个设备时,MoE 层在不同设备间共享,而其他所有层则在每个设备上复制。我们将在 “让 MoE 起飞” 部分对这一点进行更详细的讨论。
在这里插入图片描述

为了保持负载平衡和训练效率,GShard 的作者除了引入了上一节中讨论的类似辅助损失外,还引入了一些关键变化:

随机路由: 在 Top-2 设置中,我们始终选择排名最高的专家,但第二个专家是根据其权重比例随机选择的。
专家容量: 我们可以设定一个阈值,定义一个专家能处理多少令牌。如果两个专家的容量都达到上限,令牌就会溢出,并通过残差连接传递到下一层,或在某些情况下被完全丢弃。专家容量是 MoE 中最重要的概念之一。为什么需要专家容量呢?因为所有张量的形状在编译时是静态确定的,我们无法提前知道多少令牌会分配给每个专家,因此需要一个固定的容量因子。
GShard 的工作对适用于 MoE 的并行计算模式也做出了重要贡献,但这些内容的讨论超出了这篇博客的范围。

注意: 在推理过程中,只有部分专家被激活。同时,有些计算过程是共享的,例如自注意力 (self-attention) 机制,它适用于所有令牌。这就解释了为什么我们可以使用相当于 12B 稠密模型的计算资源来运行一个包含 8 个专家的 47B 模型。如果我们采用 Top-2 门控,模型会使用高达 14B 的参数。但是,由于自注意力操作 (专家间共享) 的存在,实际上模型运行时使用的参数数量是 12B。

Switch Transformers
尽管混合专家模型 (MoE) 显示出了很大的潜力,但它们在训练和微调过程中存在稳定性问题。Switch Transformers 是一项非常激动人心的工作,它深入研究了这些话题。作者甚至在 Hugging Face 上发布了一个 1.6 万亿参数的 MoE,拥有 2048 个专家,你可以使用 transformers 库来运行它。Switch Transformers 实现了与 T5-XXL 相比 4 倍的预训练速度提升。
就像在 GShard 中一样,作者用混合专家模型 (MoE) 层替换了前馈网络 (FFN) 层。Switch Transformers 提出了一个 Switch Transformer 层,它接收两个输入 (两个不同的令牌) 并拥有四个专家。
在这里插入图片描述
与最初使用至少两个专家的想法相反,Switch Transformers 采用了简化的单专家策略。这种方法的效果包括:

减少门控网络 (路由) 计算负担
每个专家的批量大小至少可以减半
降低通信成本
保持模型质量

Switch Transformers 采用了编码器 - 解码器的架构,实现了与 T5 类似的混合专家模型 (MoE) 版本。GLaM 这篇工作探索了如何使用仅为原来 1/3 的计算资源 (因为 MoE 模型在训练时需要的计算量较少,从而能够显著降低碳足迹) 来训练与 GPT-3 质量相匹配的模型来提高这些模型的规模。作者专注于仅解码器 (decoder-only) 的模型以及少样本和单样本评估,而不是微调。他们使用了 Top-2 路由和更大的容量因子。此外,他们探讨了将容量因子作为一个动态度量,根据训练和评估期间所使用的计算量进行调整。

用 Router z-loss 稳定模型训练
之前讨论的平衡损失可能会导致稳定性问题。我们可以使用许多方法来稳定稀疏模型的训练,但这可能会牺牲模型质量。例如,引入 dropout 可以提高稳定性,但会导致模型质量下降。另一方面,增加更多的乘法分量可以提高质量,但会降低模型稳定性。

ST-MoE 引入的 Router z-loss 在保持了模型性能的同时显著提升了训练的稳定性。这种损失机制通过惩罚门控网络输入的较大 logits 来起作用,目的是促使数值的绝对大小保持较小,这样可以有效减少计算中的舍入误差。这一点对于那些依赖指数函数进行计算的门控网络尤其重要。

专家的数量对预训练有何影响?
增加更多专家可以提升处理样本的效率和加速模型的运算速度,但这些优势随着专家数量的增加而递减 (尤其是当专家数量达到 256 或 512 之后更为明显)。同时,这也意味着在推理过程中,需要更多的显存来加载整个模型。值得注意的是,Switch Transformers 的研究表明,其在大规模模型中的特性在小规模模型下也同样适用,即便是每层仅包含 2、4 或 8 个专家。

对于开源的混合专家模型 (MoE),你可以关注下面这些:

Switch Transformers (Google): 基于 T5 的 MoE 集合,专家数量从 8 名到 2048 名。最大的模型有 1.6 万亿个参数。
NLLB MoE (Meta): NLLB 翻译模型的一个 MoE 变体。
OpenMoE: 社区对基于 Llama 的模型的 MoE 尝试。
Mixtral 8x7B (Mistral): 一个性能超越了 Llama 2 70B 的高质量混合专家模型,并且具有更快的推理速度。此外,还发布了一个经过指令微调的模型。有关更多信息,可以在 Mistral 的 公告博客文章 中了解。

REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttentionclass SwitchGate(nn.Module):"""SwitchGate module for MoE (Mixture of Experts) model.Args:dim (int): Input dimension.num_experts (int): Number of experts.capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments."""def __init__(self,dim,num_experts: int,capacity_factor: float = 1.0,epsilon: float = 1e-6,*args,**kwargs,):super().__init__()self.dim = dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.epsilon = epsilonself.w_gate = nn.Linear(dim, num_experts)def forward(self, x: Tensor, use_aux_loss=False):"""Forward pass of the SwitchGate module.Args:x (Tensor): Input tensor.Returns:Tensor: Gate scores."""# Compute gate scoresgate_scores = F.softmax(self.w_gate(x), dim=-1)# Determine the top-1 expert for each tokencapacity = int(self.capacity_factor * x.size(0))top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)# Mask to enforce sparsitymask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)# Combine gating scores with the maskmasked_gate_scores = gate_scores * mask# Denominatorsdenominators = (masked_gate_scores.sum(0, keepdim=True) + self.epsilon)# Norm gate scores to sum to the capacitygate_scores = (masked_gate_scores / denominators) * capacityif use_aux_loss:load = gate_scores.sum(0)  # Sum over all examplesimportance = gate_scores.sum(1)  # Sum over all experts# Aux loss is mean suqared difference between load and importanceloss = ((load - importance) ** 2).mean()return gate_scores, lossreturn gate_scores, Noneclass SwitchMoE(nn.Module):"""A module that implements the Switched Mixture of Experts (MoE) architecture.Args:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float): The capacity factor that controls the capacity of the MoE.mult (int): The multiplier for the hidden dimension of the feedforward network.experts (nn.ModuleList): The list of feedforward networks representing the experts.gate (SwitchGate): The switch gate module."""def __init__(self,dim: int,hidden_dim: int,output_dim: int,num_experts: int,capacity_factor: float = 1.0,mult: int = 4,use_aux_loss: bool = False,*args,**kwargs,):super().__init__()self.dim = dimself.hidden_dim = hidden_dimself.output_dim = output_dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.mult = multself.use_aux_loss = use_aux_lossself.experts = nn.ModuleList([FeedForward(dim, dim, mult, *args, **kwargs)for _ in range(num_experts)])self.gate = SwitchGate(dim,num_experts,capacity_factor,)def forward(self, x: Tensor):"""Forward pass of the SwitchMoE module.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor of the MoE."""# (batch_size, seq_len, num_experts)gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)# Dispatch to expertsexpert_outputs = [expert(x) for expert in self.experts]# Check if any gate scores are nan and handleif torch.isnan(gate_scores).any():print("NaN in gate scores")gate_scores[torch.isnan(gate_scores)] = 0# Stack and weight outputsstacked_expert_outputs = torch.stack(expert_outputs, dim=-1)  # (batch_size, seq_len, output_dim, num_experts)if torch.isnan(stacked_expert_outputs).any():stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0# Combine expert outputs and gating scoresmoe_output = torch.sum(gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1)return moe_output, lossclass SwitchTransformerBlock(nn.Module):"""SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.Args:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.depth (int, optional): The number of layers in the block. Defaults to 12.num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 6.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int): The multiplier for the hidden dimension in the feed-forward network.dropout (float): The dropout rate.attn_layers (nn.ModuleList): List of MultiQueryAttention layers.ffn_layers (nn.ModuleList): List of SwitchMoE layers.Examples:>>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)>>> x = torch.randn(1, 10, 512)>>> out = block(x)>>> out.shape"""def __init__(self,dim: int,heads: int,dim_head: int,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,*args,**kwargs,):super().__init__()self.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.attn = MultiQueryAttention(dim, heads, qk_ln=True * args, **kwargs)self.ffn = SwitchMoE(dim, dim * mult, dim, num_experts, *args, **kwargs)self.add_norm = nn.LayerNorm(dim)def forward(self, x: Tensor):"""Forward pass of the SwitchTransformerBlock.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor."""resi = xx, _, _ = self.attn(x)x = x + resix = self.add_norm(x)add_normed = x##### MoE #####x, _ = self.ffn(x)x = x + add_normedx = self.add_norm(x)return xclass SwitchTransformer(nn.Module):"""SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.Args:num_tokens (int): The number of tokens in the input vocabulary.dim (int): The dimensionality of the token embeddings and hidden states.heads (int): The number of attention heads.dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.*args: Additional positional arguments.**kwargs: Additional keyword arguments."""def __init__(self,num_tokens: int,dim: int,heads: int,dim_head: int = 64,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,depth: int = 4,*args,**kwargs,):super().__init__()self.num_tokens = num_tokensself.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.num_experts = num_expertsself.depth = depthself.embedding = nn.Embedding(num_tokens, dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(SwitchTransformerBlock(dim,heads,dim_head,mult,dropout,num_experts,*args,**kwargs,))self.to_out = nn.Sequential(nn.Softmax(dim=-1),nn.LayerNorm(dim),nn.Linear(dim, num_tokens),)def forward(self, x: Tensor) -> Tensor:"""Forward pass of the SwitchTransformer.Args:x (Tensor): The input tensor of shape (batch_size, sequence_length).Returns:Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens)."""# Embed tokens through embedding layerx = self.embedding(x)# Pass through the transformer block with MoE, it's in modulelistfor layer in self.layers:x = layer(x)# Project to output tokensx = self.to_out(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66528.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu为Docker配置代理

终端代理 我们平常在ubuntu终端中使用curl或git命令时,往往会很慢。 所以,首先需要给ubuntu终端环境添加代理。 查看自身那个软件的端口号,我这里是7890。 sudo gedit ~/.bashrcexport http_proxyhttp://localhost:7890 export https_pr…

【安卓开发】【Android Studio】项目构建失败提示【Could not read metadata.bin】解决方法

一、问题说明 在Android Studio中开发安卓项目时,项目构建失败,提示如下: Could not read workspace data from xxx/xxx/(某个目录,和gradle有关):could not read ...metadata.bin&#xff08…

EXCEL: (二) 常用图表

10. 图表 134-添加.删除图表元素 图表很少是一个单独的整体,而是由十几种元素/对象拼凑出来的。 学习图表就是学习当中各类元素的插删改。 ①图表中主要元素的定义 图表上的一个颜色就是一个系列。 横轴是分类轴,将每个系列都分为几类。 ②选中图…

【AniGS】论文阅读

笔记目录 1. 基本信息2. 理解(个人初步理解,随时更改)1. 基本信息 题目:AniGS: Animatable Gaussian Avatar from a Single Image with Inconsistent Gaussian Reconstruction时间:2024.12发表:arxiv机构:Alibaba等作者:Lingteng Qiu等链接直达:Project关键词:img to…

ubuntu20.04 在线安装postgresql 扩展postgis

基础配置 /etc/apt/sources.list # 添加pg官方基础配置deb http://apt.postgresql.org/pub/repos/apt/ focal-pgdg main# 添加ubuntu官方依赖(防止下载依赖错误)deb http://archive.ubuntu.com/ubuntu/ focal main restricted universe multiverse de…

【单片机】实现一个简单的ADC滤波器

实现一个 ADC的滤波器,PT1 滤波器(也称为一阶低通滤波器),用于对输入信号进行滤波处理。 typedef struct PT1FilterSettings PT1FilterSettings; struct PT1FilterSettings {//! last Filter output valueuint32_t filtValOld;//…

Kafka-go语言一命速通

记录 命令(终端操作kafka) # 验证kafka是否启动 ps -ef | grep kafka # ps -ef 命令用于显示所有正在运行的进程的详细信息 lsof -i :9092# 启动kafka brew services start zookeeper brew services start kafka# 创建topic kafka-topics --create --topic test --p…

sys.dm_exec_connections:查询与 SQL Server 实例建立的连接有关的信息以及每个连接的详细信息(客户端ip)

文章目录 引言I 基于dm_exec_connections查询客户端ip权限物理联接时间范围dm_exec_connections表see also: 监视SQL Server 内存使用量资源信号灯 DMV sys.dm_exec_query_resource_semaphores( 确定查询执行内存的等待)引言 查询历史数据库客户端ip应用场景: 安全分析缺乏…

阿里云发现后门webshell,怎么处理,怎么解决?

当收到如下阿里云通知邮件时,大部分管理员都会心里一惊吧!出现Webshell,大概是网站被入侵了。 尊敬的 xxxaliyun.com: 云盾云安全中心检测到您的服务器:47.108.x.xx(xx机)出现了紧急安全事件…

【大模型】百度千帆大模型对接LangChain使用详解

目录 一、前言 二、LangChain架构与核心组件 2.1 LangChain 核心架构 2.2 LangChain 核心组件 三、环境准备 3.1 前置准备 3.1.1 创建应用并获取apikey 3.1.2 开通付费功能 3.2 获取LangChain文档 3.3 安装LangChain依赖包 四、百度千帆大模型对接 LangChain 4.1 LL…

maven如何从外部导包

1.找到你项目的文件位置,将外部要导入的包复制粘贴进你当前要导入的项目下。 2.从你的项目目录下选中要导入的包的pom文件即可导包成功 注意一定是选中对应的pom文件 导入成功之后对应的pom.xml文件就会被点亮

如何轻松反转C# List<T>中的元素顺序

在C#中&#xff0c;有多种方法可以反转 List<T> 的元素顺序。以下是几种常见的方法&#xff1a; 方法一&#xff1a;使用 List<T>.Reverse 方法 List<T> 类提供了一个内置的 Reverse 方法&#xff0c;可以就地反转列表中的元素顺序。 using System; using…

graylog配置日志关键字邮件Email告警

文章目录 前言一、配置邮箱报警二、邮件告警配置1.设置邮箱告警通知2.设置告警事件 三、告警结果示例 前言 在TO/B、TO/C项目的告警链路中&#xff0c;端口告警、服务器基础资源告警、接口告警等不同类型的告警都扮演着重要角色。这些告警能够帮助团队及时发现潜在的系统问题&…

GraphQL:强大的API查询语言

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

集合卡尔曼滤波ENKF的学习笔记

集合卡尔曼滤波ENKF的学习笔记 什么是数据同化&#xff1f;天气模型预测是25C&#xff08;这叫"背景场"&#xff09;&#xff0c;实际观测是23C&#xff08;这叫"观测值"&#xff09;&#xff0c;数据同化就是把这两个信息合理结合。 # 最简单的组合方式&…

Spring AI ectorStore

Spring AI中的VectorStore是一种用于存储和检索高维向量数据的数据库或存储解决方案&#xff0c;它在AI应用中扮演着至关重要的角色。以下是对Spring AI VectorStore的详细解析&#xff1a; 一、VectorStore的基本概念 定义&#xff1a;VectorStore特别适用于处理那些经过嵌入…

DeepSeek-V3 通俗详解:从诞生到优势,以及与 GPT-4o 的对比

1. DeepSeek 的前世今生 1.1 什么是 DeepSeek&#xff1f; DeepSeek 是一家专注于人工智能技术研发的公司&#xff0c;致力于打造高性能、低成本的 AI 模型。它的目标是让 AI 技术更加普惠&#xff0c;让更多人能够用上强大的 AI 工具。 1.2 DeepSeek-V3 的诞生 DeepSeek-V…

UI自动化测试保姆级教程--pytest详解(精简易懂)

欢迎来到啊妮莫的学习小屋 别让过去的悲伤&#xff0c;毁掉当下的快乐一《借东西的小人阿莉埃蒂》 简介 pytest是一个用于Python的测试框架, 支持简单的单元测试和复杂的功能测试. 和Python自带的UnitTest框架类似, 但是相比于UnitTest更加简洁, 效率更高. 特点 非常容易上手…

微信小程序实现长按录音,点击播放等功能,CSS实现语音录制动画效果

有一个需求需要在微信小程序上实现一个长按时进行语音录制&#xff0c;录制时间最大为60秒&#xff0c;录制完成后&#xff0c;可点击播放&#xff0c;播放时再次点击停止播放&#xff0c;可以反复录制&#xff0c;新录制的语音把之前的语音覆盖掉&#xff0c;也可以主动长按删…

Clojure语言的数据库编程

Clojure语言的数据库编程 引言 在当今社会&#xff0c;数据的处理和管理已经成为一个不可或缺的部分。无论是互联网应用、企业系统还是移动应用&#xff0c;都需要与数据库进行频繁的交互。因此&#xff0c;选择一种合适的编程语言和相应的库来进行数据库编程显得尤为重要。C…