Batch Norm vs Layer Norm:为什么 Transformer 更适合用 Layer Norm?
1. Batch Norm 和 Layer Norm 的定义与作用
1.1 Batch Normalization (BN)
Batch Norm 是一种归一化方法,主要用于加速深层神经网络的训练。它在每个小批量(batch)中对输入的特征值进行归一化,保证特征的均值接近 0,方差接近 1,从而减小梯度消失和梯度爆炸的问题。
公式:对于第 ( i i i) 个特征:一般是一列
x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵxi−μB
其中:
- ( x i x_i xi) 是第 ( i i i) 个特征值
- ( μ B , σ B 2 \mu_B, \sigma_B^2 μB,σB2) 分别为该特征在当前 batch 的均值和方差
- ( ϵ \epsilon ϵ) 是平滑项,防止分母为 0
经过归一化后,BN 会通过两个可学习参数恢复原来的表达能力:
y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
1.2 Layer Normalization (LN)
Layer Norm 是一种对 单一输入样本的特征维度 进行归一化的方法,不依赖 batch 维度。与 Batch Norm 不同,LN 计算的是当前样本在特征维度上的均值和方差。
公式:对于样本 ( x x x) 中第 ( i i i) 个特征:一般是一行
x ^ i = x i − μ L σ L 2 + ϵ \hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} x^i=σL2+ϵxi−μL
其中:
- ( μ L , σ L 2 \mu_L, \sigma_L^2 μL,σL2) 是当前样本所有特征的均值和方差
同样,LN 会使用可学习参数 ( γ , β \gamma, \beta γ,β) 恢复特征表达能力:
y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
2. Batch Norm 和 Layer Norm 的主要区别
对比维度 | Batch Norm (BN) | Layer Norm (LN) |
---|---|---|
归一化范围 | Batch 中的每个特征 | 单个样本的所有特征 |
依赖 Batch 大小 | 依赖小批量数据的均值和方差,Batch 越小效果越差 | 不依赖 Batch 大小,适合单样本或小 Batch |
适用场景 | CNN 等图像任务,Batch 通常较大 | NLP 和 Transformer 等场景,输入序列和特征维度较多 |
稳定性 | 小 Batch 或在线学习时效果较差,可能导致波动 | 在小 Batch 和序列任务中表现稳定 |
3. 为什么 Transformer 中使用 Layer Norm?
Transformer 采用 Layer Norm 的原因可以从以下几个方面解释:
3.1 序列建模任务的特点
- Transformer 主要用于序列任务(如 NLP),输入通常为高维特征序列,长度可变。
- Layer Norm 在特征维度上归一化,适应任意长度的序列,而 Batch Norm 依赖批量大小,对序列任务不够灵活。
3.2 小批量训练的限制
- 在 NLP 任务中,由于长文本的存在,Batch Size 通常较小。Batch Norm 在小 Batch 下效果较差,易受样本均值和方差的噪声影响。
- Layer Norm 不依赖 Batch 维度,因此在小 Batch 训练中更稳定。
3.3 平稳性和梯度流动
- Transformer 使用自注意力机制,每层的输入依赖于上一层的输出。Layer Norm 在特征维度归一化,使梯度更新更加稳定。
- Batch Norm 会因为动态 Batch 均值和方差的变化导致梯度波动。
3.4 并行计算效率
- Layer Norm 的计算只涉及当前样本,适合 Transformer 的并行计算框架。
- Batch Norm 需要在小批量数据之间计算统计量,限制了并行效率。
4. 示例:数值模拟 Batch Norm 和 Layer Norm 的效果
我们通过代码模拟两种归一化方法的行为,观察它们在不同输入场景下的效果。
import torch
import torch.nn as nn# 定义数据
torch.manual_seed(42)
input_data = torch.randn(4, 5) # 4 个样本,每个样本 5 个特征
print("Input Data:\n", input_data)
#让input data =
# tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
# [ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
# [-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
# [ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]])# Batch Norm
batch_norm = nn.BatchNorm1d(5) # 对每个特征归一化
output_bn = batch_norm(input_data)
print("\nBatch Norm Output:\n", output_bn)# Layer Norm
layer_norm = nn.LayerNorm(5) # 对每个样本的特征维度归一化
output_ln = layer_norm(input_data)
print("\nLayer Norm Output:\n", output_ln)
Input:
Input Data:tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]])
输出对比
- Batch Norm 的归一化结果依赖于每列特征的 Batch 均值和方差。
- Layer Norm 的归一化结果仅依赖于每个样本的特征维度。
计算过程与解释
以下是给定代码中 Batch Norm 和 Layer Norm 的手动计算过程。我们分别以 Batch Norm 第 1 列 和 Layer Norm 第 1 行 为例,详细说明其工作原理。
1. Batch Norm 的计算过程
Batch Norm 操作:
- Batch Norm 是对每个特征(列)在整个 batch 中计算均值和方差,然后对每个特征进行归一化处理,公式如下:
BN ( x i , j ) = x i , j − μ j σ j 2 + ϵ ⋅ γ + β \text{BN}(x_{i,j}) = \frac{x_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} \cdot \gamma + \beta BN(xi,j)=σj2+ϵxi,j−μj⋅γ+β
其中:- ( x i , j x_{i,j} xi,j ) 是第 ( i i i) 个样本第 ( j j j) 个特征值。
- ( μ j \mu_j μj ) 是第 ( j j j) 列的均值,( σ j 2 \sigma_j^2 σj2) 是第 ( j j j) 列的方差。
- ( γ \gamma γ) 和 ( β \beta β) 是可学习参数(初始化为 1 和 0)。
以第 1 列为例(特征索引为 0):
输入数据第 1 列为:
Input 第 1 列 = [ 1.9269 , 1.0783 , − 0.4934 , 0.8599 ] \text{Input}_{\text{第 1 列}} = [1.9269, 1.0783, -0.4934, 0.8599] Input第 1 列=[1.9269,1.0783,−0.4934,0.8599]
-
计算均值:
μ 0 = 1.9269 + 1.0783 − 0.4934 + 0.8599 4 = 0.84293 \mu_0 = \frac{1.9269 + 1.0783 - 0.4934 + 0.8599}{4} = 0.84293 μ0=41.9269+1.0783−0.4934+0.8599=0.84293 -
计算方差:
σ 0 2 = ( 1.9269 − μ 0 ) 2 + ( 1.0783 − μ 0 ) 2 + ( − 0.4934 − μ 0 ) 2 + ( 0.8599 − μ 0 ) 2 4 = 0.71583 \sigma_0^2 = \frac{(1.9269 - \mu_0)^2 + (1.0783 - \mu_0)^2 + (-0.4934 - \mu_0)^2 + (0.8599 - \mu_0)^2}{4} = 0.71583 σ02=4(1.9269−μ0)2+(1.0783−μ0)2+(−0.4934−μ0)2+(0.8599−μ0)2=0.71583 -
归一化操作(假设 ( γ = 1 , β = 0 \gamma = 1, \beta = 0 γ=1,β=0)):
每个元素按以下公式计算:
BN ( x i , 0 ) = x i , 0 − μ 0 σ 0 2 + ϵ \text{BN}(x_{i,0}) = \frac{x_{i,0} - \mu_0}{\sqrt{\sigma_0^2 + \epsilon}} BN(xi,0)=σ02+ϵxi,0−μ0
取 ( ϵ = 1 0 − 5 \epsilon = 10^{-5} ϵ=10−5),结果如下:
BN ( 1.9269 ) = 1.9269 − 0.84293 0.71583 + 1 0 − 5 ≈ 1.2794 \text{BN}(1.9269) = \frac{1.9269 - 0.84293}{\sqrt{0.71583 + 10^{-5}}} \approx 1.2794 BN(1.9269)=0.71583+10−51.9269−0.84293≈1.2794
BN ( 1.0783 ) = 1.0783 − 0.84293 0.71583 + 1 0 − 5 ≈ 0.2774 \text{BN}(1.0783) = \frac{1.0783 - 0.84293}{\sqrt{0.71583 + 10^{-5}}} \approx 0.2774 BN(1.0783)=0.71583+10−51.0783−0.84293≈0.2774
BN ( − 0.4934 ) = − 0.4934 − 0.84293 0.71583 + 1 0 − 5 ≈ − 1.5784 \text{BN}(-0.4934) = \frac{-0.4934 - 0.84293}{\sqrt{0.71583 + 10^{-5}}} \approx -1.5784 BN(−0.4934)=0.71583+10−5−0.4934−0.84293≈−1.5784
BN ( 0.8599 ) = 0.8599 − 0.84293 0.71583 + 1 0 − 5 ≈ 0.0216 \text{BN}(0.8599) = \frac{0.8599 - 0.84293}{\sqrt{0.71583 + 10^{-5}}} \approx 0.0216 BN(0.8599)=0.71583+10−50.8599−0.84293≈0.0216
第 1 列的输出为:
BN 第 1 列输出 = [ 1.2794 , 0.2774 , − 1.5784 , 0.0216 ] \text{BN 第 1 列输出} = [1.2794, 0.2774, -1.5784, 0.0216] BN 第 1 列输出=[1.2794,0.2774,−1.5784,0.0216]
2. Layer Norm 的计算过程
Layer Norm 操作:
- Layer Norm 是对每个样本的所有特征(行)计算均值和方差,然后对整行归一化处理,公式如下:
LN ( x i , j ) = x i , j − μ i σ i 2 + ϵ ⋅ γ + β \text{LN}(x_{i,j}) = \frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \cdot \gamma + \beta LN(xi,j)=σi2+ϵxi,j−μi⋅γ+β
其中:- ( x i , j x_{i,j} xi,j ) 是第 ( i i i) 个样本第 ( j j j) 个特征值。
- ( μ i \mu_i μi ) 是第 ( i i i) 行的均值,( σ i 2 \sigma_i^2 σi2) 是第 ( i i i) 行的方差。
- ( γ \gamma γ) 和 ( β \beta β) 是可学习参数(初始化为 1 和 0)。
以第 1 行为例(样本索引为 0):
输入数据第 1 行为:
Input 第 1 行 = [ 1.9269 , 1.4873 , 0.9007 , − 2.1055 , − 0.7581 ] \text{Input}_{\text{第 1 行}} = [1.9269, 1.4873, 0.9007, -2.1055, -0.7581] Input第 1 行=[1.9269,1.4873,0.9007,−2.1055,−0.7581]
-
计算均值:
μ 0 = 1.9269 + 1.4873 + 0.9007 − 2.1055 − 0.7581 5 = 0.29026 \mu_0 = \frac{1.9269 + 1.4873 + 0.9007 - 2.1055 - 0.7581}{5} = 0.29026 μ0=51.9269+1.4873+0.9007−2.1055−0.7581=0.29026 -
计算方差:
σ 0 2 = ( 1.9269 − μ 0 ) 2 + ( 1.4873 − μ 0 ) 2 + ( 0.9007 − μ 0 ) 2 + ( − 2.1055 − μ 0 ) 2 + ( − 0.7581 − μ 0 ) 2 5 = 2.01125 \sigma_0^2 = \frac{(1.9269 - \mu_0)^2 + (1.4873 - \mu_0)^2 + (0.9007 - \mu_0)^2 + (-2.1055 - \mu_0)^2 + (-0.7581 - \mu_0)^2}{5} = 2.01125 σ02=5(1.9269−μ0)2+(1.4873−μ0)2+(0.9007−μ0)2+(−2.1055−μ0)2+(−0.7581−μ0)2=2.01125 -
归一化操作(假设 ( γ = 1 , β = 0 \gamma = 1, \beta = 0 γ=1,β=0)):
每个元素按以下公式计算:
LN ( x 0 , j ) = x 0 , j − μ 0 σ 0 2 + ϵ \text{LN}(x_{0,j}) = \frac{x_{0,j} - \mu_0}{\sqrt{\sigma_0^2 + \epsilon}} LN(x0,j)=σ02+ϵx0,j−μ0
取 ( ϵ = 1 0 − 5 \epsilon = 10^{-5} ϵ=10−5),结果如下:
LN ( 1.9269 ) = 1.9269 − 0.29026 2.01125 + 1 0 − 5 ≈ 1.1526 \text{LN}(1.9269) = \frac{1.9269 - 0.29026}{\sqrt{2.01125 + 10^{-5}}} \approx 1.1526 LN(1.9269)=2.01125+10−51.9269−0.29026≈1.1526
LN ( 1.4873 ) = 1.4873 − 0.29026 2.01125 + 1 0 − 5 ≈ 0.8415 \text{LN}(1.4873) = \frac{1.4873 - 0.29026}{\sqrt{2.01125 + 10^{-5}}} \approx 0.8415 LN(1.4873)=2.01125+10−51.4873−0.29026≈0.8415
LN ( 0.9007 ) = 0.9007 − 0.29026 2.01125 + 1 0 − 5 ≈ 0.4318 \text{LN}(0.9007) = \frac{0.9007 - 0.29026}{\sqrt{2.01125 + 10^{-5}}} \approx 0.4318 LN(0.9007)=2.01125+10−50.9007−0.29026≈0.4318
LN ( − 2.1055 ) = − 2.1055 − 0.29026 2.01125 + 1 0 − 5 ≈ − 1.6822 \text{LN}(-2.1055) = \frac{-2.1055 - 0.29026}{\sqrt{2.01125 + 10^{-5}}} \approx -1.6822 LN(−2.1055)=2.01125+10−5−2.1055−0.29026≈−1.6822
LN ( − 0.7581 ) = − 0.7581 − 0.29026 2.01125 + 1 0 − 5 ≈ − 0.7437 \text{LN}(-0.7581) = \frac{-0.7581 - 0.29026}{\sqrt{2.01125 + 10^{-5}}} \approx -0.7437 LN(−0.7581)=2.01125+10−5−0.7581−0.29026≈−0.7437
第 1 行的输出为:
LN 第 1 行输出 = [ 1.1526 , 0.8415 , 0.4318 , − 1.6822 , − 0.7437 ] \text{LN 第 1 行输出} = [1.1526, 0.8415, 0.4318, -1.6822, -0.7437] LN 第 1 行输出=[1.1526,0.8415,0.4318,−1.6822,−0.7437]
更详细的例子可以参考笔者的另一篇博客:以[Today is great] [ How are you]两句话为例:学习Batch Norm和Layer Norm
总结
- Batch Norm 是对每列(特征维度)进行归一化操作,关注每个特征的分布。适合于计算机视觉任务中输入数据的特征相似性较强的情况。
- Layer Norm 是对每行(样本维度)进行归一化操作,关注样本自身的特征平衡。适合于 NLP 和序列建模中每个输入样本维度分布差异较大的场景。
5. Transformer 中 Layer Norm 的代码实现
在 Transformer 中,Layer Norm 通常应用于每个子层(子层归一化,SubLayer Norm),确保模型在深层架构中保持数值稳定。
import torch
import torch.nn as nnclass TransformerLayer(nn.Module):def __init__(self, d_model):super(TransformerLayer, self).__init__()self.self_attention = nn.MultiheadAttention(d_model, num_heads=8)self.feed_forward = nn.Sequential(nn.Linear(d_model, d_model * 4),nn.ReLU(),nn.Linear(d_model * 4, d_model))self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(0.1)def forward(self, x):# Self-Attention + Residual Connection + Layer Normattn_output, _ = self.self_attention(x, x, x)x = self.norm1(x + self.dropout(attn_output))# Feed-Forward + Residual Connection + Layer Normff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x# 测试 Transformer Layer
x = torch.randn(10, 16, 512) # 序列长度 10,Batch Size 16,特征维度 512
layer = TransformerLayer(d_model=512)
output = layer(x)
print("Transformer Output Shape:", output.shape)
6. Layer Norm 在 NLP 任务中的应用
在 NLP 中,Layer Norm 被广泛应用于 Transformer 模型,如 BERT、GPT 等。
典型的应用:
- BERT:在每个 Encoder 子层中使用 Layer Norm。
- GPT:在 Decoder 中结合残差连接使用 Layer Norm。
以下是 Layer Norm 在 NLP 任务中提升稳定性的原因:
- 输入多样性:句子长度和特征分布差异大,Layer Norm 可适应这些变化。
- 小批量训练:NLP 任务常用小 Batch,Layer Norm 保持稳定。
- 多头注意力机制:Layer Norm 确保特征维度的均衡分布,提升注意力权重的稳定性。
7. 总结
维度 | Batch Norm | Layer Norm |
---|---|---|
归一化范围 | Batch 内每列特征 | 每个样本的所有特征 |
适用场景 | 图像任务(CNN 等) | NLP 和序列建模(Transformer 等) |
对 Batch 大小的依赖 | 依赖 Batch 大小,Batch 越大越稳定 | 无需依赖 Batch 大小 |
数值稳定性 | 小 Batch 下梯度可能波动较大 | 无论 Batch 大小如何均稳定 |
Layer Norm 的灵活性和稳定性,使其成为 Transformer 和 NLP 任务的首选归一化方法,在深层序列模型中尤为重要。
后记
2024年12月14日17点01分于上海,在GPT4o大模型辅助下完成。