详解三种常用标准化 Batch Norm Layer Norm RMSNorm

  • 参考:
    • BN究竟起了什么作用?一个闭门造车的分析
    • 《动手学深度学习》7.5 节

  • 深度学习中,归一化是常用的稳定训练的手段,CV 中常用 Batch Norm; Transformer 类模型中常用 layer norm,而 RMSNorm 是近期很流行的 LaMMa 模型使用的标准化方法,它是 Layer Norm 的一个变体

  • 值得注意的是,这里所谓的归一化严格讲应该称为 标准化Standardization ,它描述一种把样本调整到均值为 0,方差为 1 的缩放平移操作。归一化、标准化、正则化等术语常常被混用,可以看 标准化、归一化概念梳理(附代码) 这篇文章理清

  • 详细讨论前,先粗略看一下 Batch Norm 和 Layer Norm 的区别
    在这里插入图片描述

    1. BatchNorm是对整个 batch 样本内的每个特征做归一化,这消除了不同特征之间的大小关系,但是保留了不同样本间的大小关系。BatchNorm 适用于 CV 领域,这时输入尺寸为 b × c × h × w b\times c\times h\times w b×c×h×w (批量大小x通道x长x宽),图像的每个通道 c c c 看作一个特征,BN 可以把各通道特征图的数量级调整到差不多,同时保持不同图片相同通道特征图间的相对大小关系
    2. LayerNorm是对每个样本的所有特征做归一化,这消除了不同样本间的大小关系,但是保留了一个样本内不同特征之间的大小关系。LayerNorm 适用于 NLP 领域,这时输入尺寸为 b × l × d b\times l\times d b×l×d (批量大小x序列长度x嵌入维度),如下图所示
      在这里插入图片描述

    注意这时长 l l l 的 token 序列中,每个 token 对应一个长为 d d d 的特征向量,LayerNorm 会对各个 token 执行 l l l 次归一化计算,保留每个 token d d d 维嵌入内部的相对大小关系,同时拉近了不同 token 对应特征向量间的距离。与之相比,BN 会消除 d d d 维特征向量各维度之间的大小关系,破坏了 token 的特征(以下第 2 节会进一步说明这一点)

文章目录

  • 1. Batch Normalization
    • 1.1 原理
  • 2. Layer Normalization
  • 3. RMSNorm

1. Batch Normalization

1.1 原理

  • BN 对同一 batch 内同一通道的所有数据进行归一化,设输入的 batch data 为 x \pmb{x} x,BN 运算如下
    B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x})=\boldsymbol{\gamma} \odot \frac{\mathbf{x}-\hat{\boldsymbol{\mu}}_{\mathcal{B}}}{\hat{\boldsymbol{\sigma}}_{\mathcal{B}}}+\boldsymbol{\beta} . BN(x)=γσ^Bxμ^B+β. 其中 ⊙ \odot 表示按位置乘, γ \pmb{\gamma} γ β \pmb{\beta} β拉伸参数scale偏移参数shift,这两个参数的 size 和特征维数相同,代表着把第 i i i 个特征的 batch 分布的均值和方差移动到 β i , γ i \beta^i, \gamma^i βi,γi γ \pmb{\gamma} γ β \pmb{\beta} β 是需要与其他模型参数一起学习的参数 μ ^ B \hat{\boldsymbol{\mu}}_{\mathcal{B}} μ^B σ ^ B \hat{\boldsymbol{\sigma}}_{\mathcal{B}} σ^B 表示 batch data 中各特征的均值和方差,如下计算
    μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ \begin{aligned} \hat{\boldsymbol{\mu}}_{\mathcal{B}}&=\frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x} \\ \hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2}&=\frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}}\left(\mathbf{x}-\hat{\boldsymbol{\mu}}_{\mathcal{B}}\right)^{2}+\epsilon \end{aligned} μ^Bσ^B2=B1xBx=B1xB(xμ^B)2+ϵ 注意我们在方差估计值中添加一个小的常量 ϵ \epsilon ϵ,以确保我们永远不会尝试除以零

  • 注意一些细节

    1. 在 MLP 中应用 BN 时,均值和方差的计算发生在各个特征维度上。此时输入数据形式通常为 x ∈ R b × n \pmb{x}\in\mathbb{R}^{b\times n} xRb×n,其中 b = ∣ B ∣ b=|\mathcal{B}| b=B 为 batch size, n n n 为特征维度,有 μ ^ B , σ ^ B 2 , γ , β ∈ R 1 × n \hat{\boldsymbol{\mu}}_{\mathcal{B}},\hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2},\pmb{\gamma},\pmb{\beta} \in \mathbb{R}^{1\times n} μ^B,σ^B2,γ,βR1×n

    2. 在 CNN 中应用 BN 时,均值和方差的计算发生在各个通道上。此时输入数据形式通常为 x ∈ R b × c × h × w \pmb{x}\in\mathbb{R}^{b\times c\times h\times w} xRb×c×h×w,其中 b = ∣ B ∣ b=|\mathcal{B}| b=B 为 batch size, c , h , w c, h, w c,h,w 分别为为通道数量和图像长宽尺寸,有 μ ^ B , σ ^ B 2 , γ , β ∈ R 1 × c × 1 × 1 \hat{\boldsymbol{\mu}}_{\mathcal{B}},\hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2},\pmb{\gamma},\pmb{\beta} \in \mathbb{R}^{1\times c \times 1\times 1} μ^B,σ^B2,γ,βR1×c×1×1,如下图所示
      在这里插入图片描述

    3. BN 层在”训练模式“(通过小批量统计数据规范化)和“预测模式”(通过数据集统计规范化)中的功能不同。 训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型;预测模式下,可以根据整个数据集精确计算批量规范化所需的平均值和方差

  • BatchNorm是一种在深度学习训练中广泛使用的归一化技术,有很多好处,包括正则化效应、减少过拟合、减少对权重初始值的依赖、允许使用更高的学习率等

    • 一方面,BN 使每一层隐藏值分布主动居中,并将它们重新调整为学习到的最佳均值和方差,这种操作可能将参数的量级进行了统一,因此直觉上往往被认为可以使优化更加平滑
    • 另一方面,BN 有效性的科学性解释一度存在争议。15 年提出BN的论文声称 BN 减小了所谓的 内部协变量偏移internal covariate shift,因此可以提高模型性能,但其分析中假设了每层隐变量值都服从某种正态分布,这个假设过强了,很多后续研究指出了其问题。18 年的论文 How Does Batch Normalization Help Optimization? 认为 BN 的主要作用是使得整个损失函数的 landscape 更为平滑,从而使得我们可以更平稳地进行训练。相关分析可以参考苏神的博文
  • 示例代码参考自《动手学深度学习》7.5 节,适用于全连接层和卷积层,训练过程中使用滑动平均法计算 batch 数据的均值和方差;评估过程中使用最新的均值和方差结果

    class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2: # 全连接层shape = (1, num_features)else:             # 卷积层shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def batch_norm(self, X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)                                       # (num_features,)var = ((X - mean) ** 2).mean(dim=0)                        # (num_features,)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。mean = X.mean(dim=(0, 2, 3), keepdim=True)                # (1,num_features,1,1) 保持X的形状,以便后面可以做广播运算var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True) # (1,num_features,1,1)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.datadef forward(self, X):# 如果X不在内存上,将moving_mean和moving_var,复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = self.batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
    

2. Layer Normalization

  • LN 主要用于 NLP 领域,它对每个 token 的特征向量进行归一化计算。设某个 token 的特征向量为 x ∈ R d \pmb{x}\in \mathbb{R}^d xRd,LN 运算如下
    L N ( x ) = γ ⊙ x − μ ^ σ ^ + β . \mathrm{LN}(\mathbf{x})=\boldsymbol{\gamma} \odot \frac{\mathbf{x}-\hat{\boldsymbol{\mu}}}{\hat{\boldsymbol{\sigma}}}+\boldsymbol{\beta} . LN(x)=γσ^xμ^+β. 其中 ⊙ \odot 表示按位置乘, γ , β ∈ R d \pmb{\gamma}, \pmb{\beta}\in \mathbb{R}^d γ,βRd 和 是 拉伸参数scale偏移参数shift,代表着把第 i i i 个特征的 batch 分布的均值和方差移动到 β i , γ i \beta^i, \gamma^i βi,γi γ \pmb{\gamma} γ β \pmb{\beta} β 是需要与其他模型参数一起学习的参数 μ ^ \hat{\boldsymbol{\mu}} μ^ σ ^ \hat{\boldsymbol{\sigma}} σ^ 表示特征向量所有元素的均值和方差,如下计算
    μ ^ = 1 d ∑ x i ∈ x x i σ ^ 2 = 1 d ∑ x i ∈ x ( x i − μ ^ ) 2 + ϵ \begin{aligned} \hat{\boldsymbol{\mu}}&=\frac{1}{d} \sum_{x^i \in \mathbf{x}} x^i \\ \hat{\boldsymbol{\sigma}}^{2}&=\frac{1}{d} \sum_{x^i \in \mathbf{x}}\left(x^i-\hat{\boldsymbol{\mu}}\right)^{2}+\epsilon \end{aligned} μ^σ^2=d1xixxi=d1xix(xiμ^)2+ϵ 注意我们在方差估计值中添加一个小的常量 ϵ \epsilon ϵ,以确保我们永远不会尝试除以零

  • 给定一个长 l l l 的句子,LN 要进行 l l l 次归一化计算,之后对每个特征维度施加统一的拉伸和偏移,如下图所示
    在这里插入图片描述

  • 为什么 LN 比 BN 更适用于 Transformer 类模型呢,这是因为 transformer 模型是基于相似度的,把序列中的每个 token 的特征向量进行归一化有利于模型学习语义,第一步调整均值方差时,相当于对把各个 token 的特征向量缩放到统一的尺度,第二步施加 γ , β \pmb{\gamma, \beta} γ,β 时,相当于对所有 token 的特征向量进行了统一的 transfer,这不会破坏 token 特征向量间的相对角度,因此不会破坏学到的语义信息。与之相对的,BN 沿着特征维度进行归一化,这时对序列中各个 token 施加的 transfer 是不同的,破坏了 token 特征向量间的相对角度关系

  • Transformer 类模型中,LayerNorm 层有两种放置方式
    Pre Norm: x t + 1 = x t + F t ( Norm ⁡ ( x t ) ) Post Norm: x t + 1 = Norm ⁡ ( x t + F t ( x t ) ) \text{Pre Norm:} \quad \boldsymbol{x}_{t+1}=\boldsymbol{x}_{t}+F_{t}\left(\operatorname{Norm}\left(\boldsymbol{x}_{t}\right)\right) \\ \text{Post Norm:} \quad \boldsymbol{x}_{t+1}=\operatorname{Norm}\left(\boldsymbol{x}_{t}+F_{t}\left(\boldsymbol{x}_{t}\right)\right) Pre Norm:xt+1=xt+Ft(Norm(xt))Post Norm:xt+1=Norm(xt+Ft(xt)) 如下图所示
    在这里插入图片描述

    目前比较明确的结论是:同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm

    1. Pre Norm 更容易训练好理解,因为它的恒等路径更突出
    2. Pre Norm 中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”,这是因为 Pre Norm 结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而 Post Norm 刚刚相反,它每Norm一次就削弱一次恒等分支的权重,所以 Post Norm 反而是更突出残差分支的,因此Post Norm中的层数更加有分量,起到了作用,一旦训练好之后效果更优。详细说明参考 为什么Pre Norm的效果不如Post Norm?
  • 过去 BERT 主流的时代往往使用 Post Norm,现在 GPT 时代模型规模都很大,因此大多用 Pre Norm 来稳定训练

3. RMSNorm

  • RMSNorm 是 LayerNorm 的一个简单变体,来自 2019 年的论文 Root Mean Square Layer Normalization,被 T5 和当前流行 lamma 模型所使用。其提出的动机是 LayerNorm 运算量比较大,所提出的RMSNorm 性能和 LayerNorm 相当,但是可以节省7%到64%的运算
  • RMSNorm和LayerNorm的主要区别在于RMSNorm不需要同时计算均值和方差两个统计量,而只需要计算均方根 Root Mean Square 这一个统计量,公式如下
    RMSNorm ( x ) = γ ⊙ x RMS ⁡ ( x ) where  RMS ⁡ ( x ) = 1 d ∑ x i ∈ x x i 2 + ϵ \text{RMSNorm}(\pmb{x})=\boldsymbol{\gamma} \odot\frac{\pmb{x}}{\operatorname{RMS}(x)} \quad \text{where \quad}\operatorname{RMS}(x)=\sqrt{\frac{1}{d} \sum_{x^i \in \mathbf{x}} x_{i}^{2} + \epsilon} RMSNorm(x)=γRMS(x)xwhere RMS(x)=d1xixxi2+ϵ
  • 论文 Do Transformer Modifications Transfer Across Implementations and Applications? 中做了比较充分的对比实验,显示出RMS Norm的优越性。一个直观的猜测是,计算均值所代表的 center 操作类似于全连接层的 bias 项,储存到的是关于预训练任务的一种先验分布信息,而把这种先验分布信息直接储存在模型中,反而可能会导致模型的迁移能力下降
  • 下面给出 Transformer Lamma 源码中实现的 RMSNorm
    class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/859968.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pyppeteer原理介绍和入门尝试

pyppeteer仓库地址:https://github.com/miyakogi/pyppeteer puppeteer仓库地址:https://github.com/search?qpuppeteer&typerepositories 因为有些网页是可以检测到是否是使用了selenium。并且selenium所谓的保护机制不允许跨域cookies保存以及登…

测试的基础知识大全【测试概念、分类、模型、流程、测试用例书写、用例设计、Bug、基础功能测试实战】

测试基础笔记 Day01阶段⽬标⼀、测试介绍⼆、测试常⽤分类2.1 阶段划分单元测试集成测试系统测试验收测试 2.2 代码可⻅度划分⿊盒测试:主要针对功能(阶段划分->系统测试)灰盒测试:针对接⼝测试(阶段划分->集成测…

【UEFI实战】HttpBoot

环境配置 首先下载tftpd工具,可以在phjounin / tftpd64 / Downloads — Bitbucket下载到,建议不要安装到C盘,因为可能无法修改其配置。配置tftpd工具的DHCP服务: 注意这里的IP地址需要跟实际网卡IP匹配。 下载Apache&#xff0c…

【TensorRT】TensorRT C# API 项目更新 (2):优化安装方式和代码

1. 项目介绍 NVIDIA TensorRT™ 是一款用于高性能深度学习推理的 SDK,包括深度学习推理优化器和运行时,可为推理应用程序提供低延迟和高吞吐量。基于 NVIDIA TensorRT 的应用程序在推理过程中的执行速度比纯 CPU 平台快 36 倍,使您能够优化在…

扣子/coze智能体开发的经验与避坑指南

近期,我计划几场关于分享智能体应用开发的活动。因此,我顺便总结了我在创建智能体过程中遇到的问题和解决方案,帮助大家避免类似的陷阱,提高智能体的性能和用户体验。以下是我总结的几点关键经验。 1. 人设与回复逻辑的提示词 在…

《C++ Primer》导学系列:第 8 章 - IO库

8.1 IO类 C标准库提供了一套丰富的输入输出&#xff08;IO&#xff09;类&#xff0c;用于处理数据的输入输出操作。这些类位于<iostream>头文件中&#xff0c;包括处理标准输入输出的istream和ostream类&#xff0c;处理文件输入输出的ifstream和ofstream类&#xff0c…

索引的分类和回表查询——Java全栈知识(29)

索引的分类和回表查询 Mysql 的索引按照类型可以分为以下几类&#xff0c;但是我们使用的 InnoDB 只支持主键索引&#xff0c;唯一索引&#xff0c;普通索引&#xff0c;并不支持全文索引。 1、聚集索引和二级索引 InnoDB 可以将索引分为两类分别是聚集索引和二级索引&…

编译原理大题自解(活前缀DFA、LR(0)分析表)

目录 4. (简答题) &#xff08;1&#xff09;给出识别活前缀的DFA &#xff08;2&#xff09;设计此文法的 LR(0)分析表 第一种解法 第二种解放 首先声明这是作者的写法&#xff08;不保证正确&#xff01;&#xff09;仅供参考。本题因为可能存在冲突的原因&#xff0c;所…

SpringCloud分布式微服务链路追踪方案:Zipkin

创作博客的目的是希望将自己掌握的知识系统地整理一下&#xff0c;并以博客的形式记录下来。这不仅是为了帮助其他有需要的人查阅相关内容&#xff0c;也是为了自己能够更好地巩固和加深对这些知识的理解。创作的时候也是对自己所学的一次复盘和总结&#xff0c;在创作的过程中…

【例子】webpack配合babel实现 es6 语法转 es5 案例 [通俗易懂]

首先来说一下实现 es6 转 es5 的一个简单步骤 1、新建一个项目&#xff0c;并且在命令行中初始化项目 npm init -y2、安装对应版本的 webpack webpack-cli(命令行工具) "webpack""webpack-cli"3、安装 Babel 核心库和相关的 loader "babel-core&qu…

PasteSpiderFile文件同步管理端使用说明(V24.6.21.1)

PasteSpider作为一款适合开发人员的部署管理工具&#xff0c;特意针对开发人员的日常情况做了一个PasteSpiderFile客户端&#xff0c;用于windows上的开发人员迅速的更新发布自己的最新代码到服务器上&#xff01; 虽然PasteSpider也支持svn/git的源码拉取&#xff0c;自动编译…

【自然语言处理系列】安装nltk_data和punkt库(亲测有效)

目录 一、下载nltk_data-gh-pages.zip数据文件 二、将nltk_data文件夹移到对应的目录 三、测试 四、成功调用punkt库 问题&#xff1a; 解决方案&#xff1a; 在使用自然语言处理库nltk时&#xff0c;许多初学者会遇到“nltk.download(punkt)”无法正常下载的问题。本…

Android Media Framework(七)MediaCodecService

Android引入Treble架构后&#xff0c;OpenMAX框架以HIDL Service的形式为System分区提供服务&#xff0c;本文将探讨该服务是如何启动&#xff0c;服务提供了什么内容&#xff0c;以及服务是如何被应用层所使用的。 1 概述 在Android的Treble架构中&#xff0c;为了确保系统的…

面试经典150题

打家劫舍 class Solution { public:int rob(vector<int>& nums) {int n nums.size();if(n 1){return nums[0];}vector<int> dp(n, 0);dp[0] nums[0];//有一间房可以偷//有两间房可以偷if(nums[1] > nums[0]){dp[1] nums[1];}else{dp[1] nums[0];}for …

react18 实现具名插槽

效果预览 技术要点 当父组件给子组件传递的 JSX 超过一个标签时&#xff0c;子组件接收到的 children 是一个数组&#xff0c;通过解析数组中各 JSX 的属性 slot &#xff0c;即可实现具名插槽的分发&#xff01; 代码实现 Father.jsx import Child from "./Child";…

【D3.js in Action 3 精译】第一部分 D3.js 基础知识

第一部分 D3.js 基础知识 欢迎来到 D3.js 的世界&#xff01;可能您已经迫不及待想要构建令人惊叹的数据可视化项目了。我们保证&#xff0c;这一目标很快就能达成&#xff01;但首先&#xff0c;我们必须确保您已经掌握了 D3.js 的基础知识。这一部分提到的概念将会在您后续的…

探秘神经网络激活函数:Sigmoid、Tanh和ReLU,解析非线性激活函数的神奇之处

引言 在神经网络中&#xff0c;激活函数扮演着至关重要的角色。它们赋予神经网络非线性的能力&#xff0c;使得网络具备学习和表示复杂函数关系的能力。本文将详细解析三种常见的激活函数&#xff1a;Sigmoid、Tanh和ReLU&#xff0c;揭开它们在神经网络中的奥秘。无论你是初学…

【十一】【QT开发应用】模拟腾讯会议登录界面设计UI

ui 加入会议的样式表 QPushButton { /* 前景色 */ color:#0054E6; /* 背景色 */ background-color:rgb(255,255,255); /* 边框风格 */ border-style:outset; /* 边框宽度 */ border-width:0.5px; /* 边框颜色 */ border-color:gray; /* 边框倒角 */ border-radius…

日常-----最爱的人

今日话题 大家好嗷&#xff0c;今天聊的技术可比之前的重要的多啊&#xff0c;哼哼&#xff0c;也不是今天&#xff0c;大家像我看齐嗷&#xff0c;我宣布个事情&#xff01;&#xff01;&#xff01; 于2024年6月21日晚上&#xff0c;本人遇到了这一生最爱的人 嘿嘿 这种事…

微信小程序 引入MiniProgram Design失败

这tm MiniProgramDesign 是我用过最垃圾的框架没有之一 我按照官网的指示安装居然能安装不成功,牛! 这里说明我是用js开发的 到以上步骤没有报错什么都没有,然后在引入组件的时候报错 Component is not found in path “./miniprogram _npm/vant/weapp/button/index” (using…