下面是详细实现的LayerNorm和RMSNorm代码,并附有详细注释。
LayerNorm 实现
LayerNorm的主要思想是对每个样本的每一层进行归一化。具体的实现如下:
import torch
import torch.nn as nnclass LayerNorm(nn.Module):def __init__(self, d_model, eps=1e-6):"""初始化 LayerNorm 模块。参数:d_model: 输入张量的最后一个维度的大小。eps: 防止除零错误的一个小常数。"""super(LayerNorm, self).__init__()self.gamma = nn.Parameter(torch.ones(d_model)) # 缩放参数self.beta = nn.Parameter(torch.zeros(d_model)) # 平移参数self.eps = epsdef forward(self, x):"""前向传播函数。参数:x: 输入张量,形状为 [batch_size, seq_len, d_model]。返回:归一化后的张量,形状同输入张量。"""mean = x.mean(-1, keepdim=True) # 计算均值std = x.std(-1, keepdim=True) # 计算标准差x_norm = (x - mean) / (std + self.eps) # 标准化return self.gamma * x_norm + self.beta # 缩放和平移# 测试 LayerNorm
x = torch.randn(2, 3, 4) # 示例输入张量
layer_norm = LayerNorm(4)
output = layer_norm(x)
print("LayerNorm Output:")
print(output)
RMSNorm 实现
RMSNorm是另一种归一化方法,它使用均方根(RMS)而不是标准差来进行归一化。实现如下:
class RMSNorm(nn.Module):def __init__(self, d_model, eps=1e-6):"""初始化 RMSNorm 模块。参数:d_model: 输入张量的最后一个维度的大小。eps: 防止除零错误的一个小常数。"""super(RMSNorm, self).__init__()self.gamma = nn.Parameter(torch.ones(d_model)) # 缩放参数self.eps = epsdef forward(self, x):"""前向传播函数。参数:x: 输入张量,形状为 [batch_size, seq_len, d_model]。返回:归一化后的张量,形状同输入张量。"""rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) # 计算均方根x_norm = x / rms # 标准化return self.gamma * x_norm # 缩放# 测试 RMSNorm
x = torch.randn(2, 3, 4) # 示例输入张量
rms_norm = RMSNorm(4)
output = rms_norm(x)
print("RMSNorm Output:")
print(output)
代码解释
-
LayerNorm:
- 初始化时,创建可学习的缩放参数
gamma
和平移参数beta
,并设定一个小常数eps
以防止除零错误。 - 前向传播时,计算输入张量在最后一个维度上的均值和标准差。
- 用均值和标准差对输入张量进行标准化,然后使用
gamma
和beta
对标准化后的张量进行缩放和平移。
- 初始化时,创建可学习的缩放参数
-
RMSNorm:
- 初始化时,创建可学习的缩放参数
gamma
,并设定一个小常数eps
以防止除零错误。 - 前向传播时,计算输入张量在最后一个维度上的均方根(RMS)。
- 用RMS对输入张量进行标准化,然后使用
gamma
对标准化后的张量进行缩放。
- 初始化时,创建可学习的缩放参数
通过这种方式,LayerNorm和RMSNorm都可以有效地对输入张量进行归一化,从而提高模型的训练稳定性和性能。