RMSNorm原理及代码
在大模型中使用层归一化有如下几个因素:
- 改善网络稳定性
- 加速收敛速度
- 提高模型的泛化能力
批量归一化是对一个批次内的数据进行归一化
层归一化是对一个样本中的不同特征进行归一化
如下是LayerNorm与RMSNorm的公式
在LLaMA中使用RMSNorm替代LayerNorm,因为RMSNorm相比LayerNorm,不需要计算样本与均值的差(减少了计算量,加快了训练速度)
代码:
class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size)) # 以hidden_size大小的全1张量初始化self.variance_epsilon = eps # 给定一个很小的数,防止分母为0def forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype) # to(input_dtype)是为了保持数据类型
代码来源于:https://github.com/huggingface/transformers/tree/main/src/transformers/models/llama/modeling_llama.py