ref:https://huggingface.co/blog/zh/moe#%E7%94%A8router-z-loss%E7%A8%B3%E5%AE%9A%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83
MoEs and Transformers
Transformer 类模型明确表明,增加参数数量可以提高性能,因此谷歌使用 GShard 尝试将 Transformer 模型的参数量扩展到超过 6000 亿并不令人惊讶。
GShard 将在编码器和解码器中的每个前馈网络 (FFN) 层中的替换为使用 Top-2 门控的混合专家模型 (MoE) 层。下图展示了编码器部分的结构。这种架构对于大规模计算非常有效: 当扩展到多个设备时,MoE 层在不同设备间共享,而其他所有层则在每个设备上复制。我们将在 “让 MoE 起飞” 部分对这一点进行更详细的讨论。
为了保持负载平衡和训练效率,GShard 的作者除了引入了上一节中讨论的类似辅助损失外,还引入了一些关键变化:
随机路由: 在 Top-2 设置中,我们始终选择排名最高的专家,但第二个专家是根据其权重比例随机选择的。
专家容量: 我们可以设定一个阈值,定义一个专家能处理多少令牌。如果两个专家的容量都达到上限,令牌就会溢出,并通过残差连接传递到下一层,或在某些情况下被完全丢弃。专家容量是 MoE 中最重要的概念之一。为什么需要专家容量呢?因为所有张量的形状在编译时是静态确定的,我们无法提前知道多少令牌会分配给每个专家,因此需要一个固定的容量因子。
GShard 的工作对适用于 MoE 的并行计算模式也做出了重要贡献,但这些内容的讨论超出了这篇博客的范围。
注意: 在推理过程中,只有部分专家被激活。同时,有些计算过程是共享的,例如自注意力 (self-attention) 机制,它适用于所有令牌。这就解释了为什么我们可以使用相当于 12B 稠密模型的计算资源来运行一个包含 8 个专家的 47B 模型。如果我们采用 Top-2 门控,模型会使用高达 14B 的参数。但是,由于自注意力操作 (专家间共享) 的存在,实际上模型运行时使用的参数数量是 12B。
Switch Transformers
尽管混合专家模型 (MoE) 显示出了很大的潜力,但它们在训练和微调过程中存在稳定性问题。Switch Transformers 是一项非常激动人心的工作,它深入研究了这些话题。作者甚至在 Hugging Face 上发布了一个 1.6 万亿参数的 MoE,拥有 2048 个专家,你可以使用 transformers 库来运行它。Switch Transformers 实现了与 T5-XXL 相比 4 倍的预训练速度提升。
就像在 GShard 中一样,作者用混合专家模型 (MoE) 层替换了前馈网络 (FFN) 层。Switch Transformers 提出了一个 Switch Transformer 层,它接收两个输入 (两个不同的令牌) 并拥有四个专家。
与最初使用至少两个专家的想法相反,Switch Transformers 采用了简化的单专家策略。这种方法的效果包括:
减少门控网络 (路由) 计算负担
每个专家的批量大小至少可以减半
降低通信成本
保持模型质量
Switch Transformers 采用了编码器 - 解码器的架构,实现了与 T5 类似的混合专家模型 (MoE) 版本。GLaM 这篇工作探索了如何使用仅为原来 1/3 的计算资源 (因为 MoE 模型在训练时需要的计算量较少,从而能够显著降低碳足迹) 来训练与 GPT-3 质量相匹配的模型来提高这些模型的规模。作者专注于仅解码器 (decoder-only) 的模型以及少样本和单样本评估,而不是微调。他们使用了 Top-2 路由和更大的容量因子。此外,他们探讨了将容量因子作为一个动态度量,根据训练和评估期间所使用的计算量进行调整。
用 Router z-loss 稳定模型训练
之前讨论的平衡损失可能会导致稳定性问题。我们可以使用许多方法来稳定稀疏模型的训练,但这可能会牺牲模型质量。例如,引入 dropout 可以提高稳定性,但会导致模型质量下降。另一方面,增加更多的乘法分量可以提高质量,但会降低模型稳定性。
ST-MoE 引入的 Router z-loss 在保持了模型性能的同时显著提升了训练的稳定性。这种损失机制通过惩罚门控网络输入的较大 logits 来起作用,目的是促使数值的绝对大小保持较小,这样可以有效减少计算中的舍入误差。这一点对于那些依赖指数函数进行计算的门控网络尤其重要。
专家的数量对预训练有何影响?
增加更多专家可以提升处理样本的效率和加速模型的运算速度,但这些优势随着专家数量的增加而递减 (尤其是当专家数量达到 256 或 512 之后更为明显)。同时,这也意味着在推理过程中,需要更多的显存来加载整个模型。值得注意的是,Switch Transformers 的研究表明,其在大规模模型中的特性在小规模模型下也同样适用,即便是每层仅包含 2、4 或 8 个专家。
对于开源的混合专家模型 (MoE),你可以关注下面这些:
Switch Transformers (Google): 基于 T5 的 MoE 集合,专家数量从 8 名到 2048 名。最大的模型有 1.6 万亿个参数。
NLLB MoE (Meta): NLLB 翻译模型的一个 MoE 变体。
OpenMoE: 社区对基于 Llama 的模型的 MoE 尝试。
Mixtral 8x7B (Mistral): 一个性能超越了 Llama 2 70B 的高质量混合专家模型,并且具有更快的推理速度。此外,还发布了一个经过指令微调的模型。有关更多信息,可以在 Mistral 的 公告博客文章 中了解。
REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttentionclass SwitchGate(nn.Module):"""SwitchGate module for MoE (Mixture of Experts) model.Args:dim (int): Input dimension.num_experts (int): Number of experts.capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments."""def __init__(self,dim,num_experts: int,capacity_factor: float = 1.0,epsilon: float = 1e-6,*args,**kwargs,):super().__init__()self.dim = dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.epsilon = epsilonself.w_gate = nn.Linear(dim, num_experts)def forward(self, x: Tensor, use_aux_loss=False):"""Forward pass of the SwitchGate module.Args:x (Tensor): Input tensor.Returns:Tensor: Gate scores."""# Compute gate scoresgate_scores = F.softmax(self.w_gate(x), dim=-1)# Determine the top-1 expert for each tokencapacity = int(self.capacity_factor * x.size(0))top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)# Mask to enforce sparsitymask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)# Combine gating scores with the maskmasked_gate_scores = gate_scores * mask# Denominatorsdenominators = (masked_gate_scores.sum(0, keepdim=True) + self.epsilon)# Norm gate scores to sum to the capacitygate_scores = (masked_gate_scores / denominators) * capacityif use_aux_loss:load = gate_scores.sum(0) # Sum over all examplesimportance = gate_scores.sum(1) # Sum over all experts# Aux loss is mean suqared difference between load and importanceloss = ((load - importance) ** 2).mean()return gate_scores, lossreturn gate_scores, Noneclass SwitchMoE(nn.Module):"""A module that implements the Switched Mixture of Experts (MoE) architecture.Args:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension.hidden_dim (int): The hidden dimension of the feedforward network.output_dim (int): The output dimension.num_experts (int): The number of experts in the MoE.capacity_factor (float): The capacity factor that controls the capacity of the MoE.mult (int): The multiplier for the hidden dimension of the feedforward network.experts (nn.ModuleList): The list of feedforward networks representing the experts.gate (SwitchGate): The switch gate module."""def __init__(self,dim: int,hidden_dim: int,output_dim: int,num_experts: int,capacity_factor: float = 1.0,mult: int = 4,use_aux_loss: bool = False,*args,**kwargs,):super().__init__()self.dim = dimself.hidden_dim = hidden_dimself.output_dim = output_dimself.num_experts = num_expertsself.capacity_factor = capacity_factorself.mult = multself.use_aux_loss = use_aux_lossself.experts = nn.ModuleList([FeedForward(dim, dim, mult, *args, **kwargs)for _ in range(num_experts)])self.gate = SwitchGate(dim,num_experts,capacity_factor,)def forward(self, x: Tensor):"""Forward pass of the SwitchMoE module.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor of the MoE."""# (batch_size, seq_len, num_experts)gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)# Dispatch to expertsexpert_outputs = [expert(x) for expert in self.experts]# Check if any gate scores are nan and handleif torch.isnan(gate_scores).any():print("NaN in gate scores")gate_scores[torch.isnan(gate_scores)] = 0# Stack and weight outputsstacked_expert_outputs = torch.stack(expert_outputs, dim=-1) # (batch_size, seq_len, output_dim, num_experts)if torch.isnan(stacked_expert_outputs).any():stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0# Combine expert outputs and gating scoresmoe_output = torch.sum(gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1)return moe_output, lossclass SwitchTransformerBlock(nn.Module):"""SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.Args:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.depth (int, optional): The number of layers in the block. Defaults to 12.num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 6.*args: Variable length argument list.**kwargs: Arbitrary keyword arguments.Attributes:dim (int): The input dimension of the block.heads (int): The number of attention heads.dim_head (int): The dimension of each attention head.mult (int): The multiplier for the hidden dimension in the feed-forward network.dropout (float): The dropout rate.attn_layers (nn.ModuleList): List of MultiQueryAttention layers.ffn_layers (nn.ModuleList): List of SwitchMoE layers.Examples:>>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)>>> x = torch.randn(1, 10, 512)>>> out = block(x)>>> out.shape"""def __init__(self,dim: int,heads: int,dim_head: int,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,*args,**kwargs,):super().__init__()self.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.attn = MultiQueryAttention(dim, heads, qk_ln=True * args, **kwargs)self.ffn = SwitchMoE(dim, dim * mult, dim, num_experts, *args, **kwargs)self.add_norm = nn.LayerNorm(dim)def forward(self, x: Tensor):"""Forward pass of the SwitchTransformerBlock.Args:x (Tensor): The input tensor.Returns:Tensor: The output tensor."""resi = xx, _, _ = self.attn(x)x = x + resix = self.add_norm(x)add_normed = x##### MoE #####x, _ = self.ffn(x)x = x + add_normedx = self.add_norm(x)return xclass SwitchTransformer(nn.Module):"""SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.Args:num_tokens (int): The number of tokens in the input vocabulary.dim (int): The dimensionality of the token embeddings and hidden states.heads (int): The number of attention heads.dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.dropout (float, optional): The dropout rate. Defaults to 0.1.num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.*args: Additional positional arguments.**kwargs: Additional keyword arguments."""def __init__(self,num_tokens: int,dim: int,heads: int,dim_head: int = 64,mult: int = 4,dropout: float = 0.1,num_experts: int = 3,depth: int = 4,*args,**kwargs,):super().__init__()self.num_tokens = num_tokensself.dim = dimself.heads = headsself.dim_head = dim_headself.mult = multself.dropout = dropoutself.num_experts = num_expertsself.depth = depthself.embedding = nn.Embedding(num_tokens, dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(SwitchTransformerBlock(dim,heads,dim_head,mult,dropout,num_experts,*args,**kwargs,))self.to_out = nn.Sequential(nn.Softmax(dim=-1),nn.LayerNorm(dim),nn.Linear(dim, num_tokens),)def forward(self, x: Tensor) -> Tensor:"""Forward pass of the SwitchTransformer.Args:x (Tensor): The input tensor of shape (batch_size, sequence_length).Returns:Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens)."""# Embed tokens through embedding layerx = self.embedding(x)# Pass through the transformer block with MoE, it's in modulelistfor layer in self.layers:x = layer(x)# Project to output tokensx = self.to_out(x)return x