- 参考:
- BN究竟起了什么作用?一个闭门造车的分析
- 《动手学深度学习》7.5 节
-
深度学习中,归一化是常用的稳定训练的手段,CV 中常用 Batch Norm; Transformer 类模型中常用 layer norm,而 RMSNorm 是近期很流行的 LaMMa 模型使用的标准化方法,它是 Layer Norm 的一个变体
-
值得注意的是,这里所谓的归一化严格讲应该称为
标准化Standardization
,它描述一种把样本调整到均值为 0,方差为 1 的缩放平移操作。归一化、标准化、正则化等术语常常被混用,可以看 标准化、归一化概念梳理(附代码) 这篇文章理清 -
详细讨论前,先粗略看一下 Batch Norm 和 Layer Norm 的区别
- BatchNorm是对整个 batch 样本内的每个特征做归一化,这消除了不同特征之间的大小关系,但是保留了不同样本间的大小关系。BatchNorm 适用于 CV 领域,这时输入尺寸为 b × c × h × w b\times c\times h\times w b×c×h×w (批量大小x通道x长x宽),图像的每个通道 c c c 看作一个特征,BN 可以把各通道特征图的数量级调整到差不多,同时保持不同图片相同通道特征图间的相对大小关系
- LayerNorm是对每个样本的所有特征做归一化,这消除了不同样本间的大小关系,但是保留了一个样本内不同特征之间的大小关系。LayerNorm 适用于 NLP 领域,这时输入尺寸为 b × l × d b\times l\times d b×l×d (批量大小x序列长度x嵌入维度),如下图所示
注意这时长 l l l 的 token 序列中,每个 token 对应一个长为 d d d 的特征向量,LayerNorm 会对各个 token 执行 l l l 次归一化计算,保留每个 token d d d 维嵌入内部的相对大小关系,同时拉近了不同 token 对应特征向量间的距离。与之相比,BN 会消除 d d d 维特征向量各维度之间的大小关系,破坏了 token 的特征(以下第 2 节会进一步说明这一点)
文章目录
- 1. Batch Normalization
- 1.1 原理
- 2. Layer Normalization
- 3. RMSNorm
1. Batch Normalization
1.1 原理
-
BN 对同一 batch 内同一通道的所有数据进行归一化,设输入的 batch data 为 x \pmb{x} x,BN 运算如下
B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x})=\boldsymbol{\gamma} \odot \frac{\mathbf{x}-\hat{\boldsymbol{\mu}}_{\mathcal{B}}}{\hat{\boldsymbol{\sigma}}_{\mathcal{B}}}+\boldsymbol{\beta} . BN(x)=γ⊙σ^Bx−μ^B+β. 其中 ⊙ \odot ⊙ 表示按位置乘, γ \pmb{\gamma} γ 和 β \pmb{\beta} β 是拉伸参数scale
和偏移参数shift
,这两个参数的 size 和特征维数相同,代表着把第 i i i 个特征的 batch 分布的均值和方差移动到 β i , γ i \beta^i, \gamma^i βi,γi, γ \pmb{\gamma} γ 和 β \pmb{\beta} β 是需要与其他模型参数一起学习的参数。 μ ^ B \hat{\boldsymbol{\mu}}_{\mathcal{B}} μ^B 和 σ ^ B \hat{\boldsymbol{\sigma}}_{\mathcal{B}} σ^B 表示 batch data 中各特征的均值和方差,如下计算
μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ \begin{aligned} \hat{\boldsymbol{\mu}}_{\mathcal{B}}&=\frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x} \\ \hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2}&=\frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}}\left(\mathbf{x}-\hat{\boldsymbol{\mu}}_{\mathcal{B}}\right)^{2}+\epsilon \end{aligned} μ^Bσ^B2=∣B∣1x∈B∑x=∣B∣1x∈B∑(x−μ^B)2+ϵ 注意我们在方差估计值中添加一个小的常量 ϵ \epsilon ϵ,以确保我们永远不会尝试除以零 -
注意一些细节
-
在 MLP 中应用 BN 时,均值和方差的计算发生在各个特征维度上。此时输入数据形式通常为 x ∈ R b × n \pmb{x}\in\mathbb{R}^{b\times n} x∈Rb×n,其中 b = ∣ B ∣ b=|\mathcal{B}| b=∣B∣ 为 batch size, n n n 为特征维度,有 μ ^ B , σ ^ B 2 , γ , β ∈ R 1 × n \hat{\boldsymbol{\mu}}_{\mathcal{B}},\hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2},\pmb{\gamma},\pmb{\beta} \in \mathbb{R}^{1\times n} μ^B,σ^B2,γ,β∈R1×n
-
在 CNN 中应用 BN 时,均值和方差的计算发生在各个通道上。此时输入数据形式通常为 x ∈ R b × c × h × w \pmb{x}\in\mathbb{R}^{b\times c\times h\times w} x∈Rb×c×h×w,其中 b = ∣ B ∣ b=|\mathcal{B}| b=∣B∣ 为 batch size, c , h , w c, h, w c,h,w 分别为为通道数量和图像长宽尺寸,有 μ ^ B , σ ^ B 2 , γ , β ∈ R 1 × c × 1 × 1 \hat{\boldsymbol{\mu}}_{\mathcal{B}},\hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2},\pmb{\gamma},\pmb{\beta} \in \mathbb{R}^{1\times c \times 1\times 1} μ^B,σ^B2,γ,β∈R1×c×1×1,如下图所示
-
BN 层在”训练模式“(通过小批量统计数据规范化)和“预测模式”(通过数据集统计规范化)中的功能不同。 训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型;预测模式下,可以根据整个数据集精确计算批量规范化所需的平均值和方差
-
-
BatchNorm是一种在深度学习训练中广泛使用的归一化技术,有很多好处,包括正则化效应、减少过拟合、减少对权重初始值的依赖、允许使用更高的学习率等
- 一方面,BN 使每一层隐藏值分布主动居中,并将它们重新调整为学习到的最佳均值和方差,这种操作可能将参数的量级进行了统一,因此直觉上往往被认为可以使优化更加平滑
- 另一方面,BN 有效性的科学性解释一度存在争议。15 年提出BN的论文声称 BN 减小了所谓的
内部协变量偏移internal covariate shift
,因此可以提高模型性能,但其分析中假设了每层隐变量值都服从某种正态分布,这个假设过强了,很多后续研究指出了其问题。18 年的论文 How Does Batch Normalization Help Optimization? 认为 BN 的主要作用是使得整个损失函数的 landscape 更为平滑,从而使得我们可以更平稳地进行训练。相关分析可以参考苏神的博文
-
示例代码参考自《动手学深度学习》7.5 节,适用于全连接层和卷积层,训练过程中使用滑动平均法计算 batch 数据的均值和方差;评估过程中使用最新的均值和方差结果
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2: # 全连接层shape = (1, num_features)else: # 卷积层shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def batch_norm(self, X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0) # (num_features,)var = ((X - mean) ** 2).mean(dim=0) # (num_features,)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。mean = X.mean(dim=(0, 2, 3), keepdim=True) # (1,num_features,1,1) 保持X的形状,以便后面可以做广播运算var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True) # (1,num_features,1,1)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.datadef forward(self, X):# 如果X不在内存上,将moving_mean和moving_var,复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = self.batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
2. Layer Normalization
-
LN 主要用于 NLP 领域,它对每个 token 的特征向量进行归一化计算。设某个 token 的特征向量为 x ∈ R d \pmb{x}\in \mathbb{R}^d x∈Rd,LN 运算如下
L N ( x ) = γ ⊙ x − μ ^ σ ^ + β . \mathrm{LN}(\mathbf{x})=\boldsymbol{\gamma} \odot \frac{\mathbf{x}-\hat{\boldsymbol{\mu}}}{\hat{\boldsymbol{\sigma}}}+\boldsymbol{\beta} . LN(x)=γ⊙σ^x−μ^+β. 其中 ⊙ \odot ⊙ 表示按位置乘, γ , β ∈ R d \pmb{\gamma}, \pmb{\beta}\in \mathbb{R}^d γ,β∈Rd 和 是拉伸参数scale
和偏移参数shift
,代表着把第 i i i 个特征的 batch 分布的均值和方差移动到 β i , γ i \beta^i, \gamma^i βi,γi, γ \pmb{\gamma} γ 和 β \pmb{\beta} β 是需要与其他模型参数一起学习的参数。 μ ^ \hat{\boldsymbol{\mu}} μ^ 和 σ ^ \hat{\boldsymbol{\sigma}} σ^ 表示特征向量所有元素的均值和方差,如下计算
μ ^ = 1 d ∑ x i ∈ x x i σ ^ 2 = 1 d ∑ x i ∈ x ( x i − μ ^ ) 2 + ϵ \begin{aligned} \hat{\boldsymbol{\mu}}&=\frac{1}{d} \sum_{x^i \in \mathbf{x}} x^i \\ \hat{\boldsymbol{\sigma}}^{2}&=\frac{1}{d} \sum_{x^i \in \mathbf{x}}\left(x^i-\hat{\boldsymbol{\mu}}\right)^{2}+\epsilon \end{aligned} μ^σ^2=d1xi∈x∑xi=d1xi∈x∑(xi−μ^)2+ϵ 注意我们在方差估计值中添加一个小的常量 ϵ \epsilon ϵ,以确保我们永远不会尝试除以零 -
给定一个长 l l l 的句子,LN 要进行 l l l 次归一化计算,之后对每个特征维度施加统一的拉伸和偏移,如下图所示
-
为什么 LN 比 BN 更适用于 Transformer 类模型呢,这是因为 transformer 模型是基于相似度的,把序列中的每个 token 的特征向量进行归一化有利于模型学习语义,第一步调整均值方差时,相当于对把各个 token 的特征向量缩放到统一的尺度,第二步施加 γ , β \pmb{\gamma, \beta} γ,β 时,相当于对所有 token 的特征向量进行了统一的 transfer,这不会破坏 token 特征向量间的相对角度,因此不会破坏学到的语义信息。与之相对的,BN 沿着特征维度进行归一化,这时对序列中各个 token 施加的 transfer 是不同的,破坏了 token 特征向量间的相对角度关系
-
Transformer 类模型中,LayerNorm 层有两种放置方式
Pre Norm: x t + 1 = x t + F t ( Norm ( x t ) ) Post Norm: x t + 1 = Norm ( x t + F t ( x t ) ) \text{Pre Norm:} \quad \boldsymbol{x}_{t+1}=\boldsymbol{x}_{t}+F_{t}\left(\operatorname{Norm}\left(\boldsymbol{x}_{t}\right)\right) \\ \text{Post Norm:} \quad \boldsymbol{x}_{t+1}=\operatorname{Norm}\left(\boldsymbol{x}_{t}+F_{t}\left(\boldsymbol{x}_{t}\right)\right) Pre Norm:xt+1=xt+Ft(Norm(xt))Post Norm:xt+1=Norm(xt+Ft(xt)) 如下图所示
目前比较明确的结论是:同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm
- Pre Norm 更容易训练好理解,因为它的恒等路径更突出
- Pre Norm 中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”,这是因为 Pre Norm 结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而 Post Norm 刚刚相反,它每Norm一次就削弱一次恒等分支的权重,所以 Post Norm 反而是更突出残差分支的,因此Post Norm中的层数更加有分量,起到了作用,一旦训练好之后效果更优。详细说明参考 为什么Pre Norm的效果不如Post Norm?
-
过去 BERT 主流的时代往往使用 Post Norm,现在 GPT 时代模型规模都很大,因此大多用 Pre Norm 来稳定训练
3. RMSNorm
- RMSNorm 是 LayerNorm 的一个简单变体,来自 2019 年的论文 Root Mean Square Layer Normalization,被 T5 和当前流行 lamma 模型所使用。其提出的动机是 LayerNorm 运算量比较大,所提出的RMSNorm 性能和 LayerNorm 相当,但是可以节省7%到64%的运算
- RMSNorm和LayerNorm的主要区别在于RMSNorm不需要同时计算均值和方差两个统计量,而只需要计算均方根 Root Mean Square 这一个统计量,公式如下
RMSNorm ( x ) = γ ⊙ x RMS ( x ) where RMS ( x ) = 1 d ∑ x i ∈ x x i 2 + ϵ \text{RMSNorm}(\pmb{x})=\boldsymbol{\gamma} \odot\frac{\pmb{x}}{\operatorname{RMS}(x)} \quad \text{where \quad}\operatorname{RMS}(x)=\sqrt{\frac{1}{d} \sum_{x^i \in \mathbf{x}} x_{i}^{2} + \epsilon} RMSNorm(x)=γ⊙RMS(x)xwhere RMS(x)=d1xi∈x∑xi2+ϵ - 论文 Do Transformer Modifications Transfer Across Implementations and Applications? 中做了比较充分的对比实验,显示出RMS Norm的优越性。一个直观的猜测是,计算均值所代表的 center 操作类似于全连接层的 bias 项,储存到的是关于预训练任务的一种先验分布信息,而把这种先验分布信息直接储存在模型中,反而可能会导致模型的迁移能力下降
- 下面给出 Transformer Lamma 源码中实现的 RMSNorm
class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)