神经网络中BN层简介及位置分析

1. 简介

Batch Normalization是深度学习中常用的技巧,Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe and Szegedy, 2015) 第一次介绍了这个方法。

这个方法的命名,明明是Standardization, 非要叫Normalization, 把本来就混用、意义不明的两个词更加搅得一团糟。那standardization 和 Normalization有什么区别呢?

一般是下面这样(X是输入数据集):

  • normalization(也叫 min-max scaling),一般译做 “归一化”:

  • standardization,一般译做 “标准化”:

Batch-Norm 是一个网络层,对中间结果作上面说的 standardization 操作。实际上 standardization 也可以叫做 Z-score normalization。所以可以这样理解,standardization 是一种特殊的 normalization。normalization 作为一个 scaling 的大类,包括 min-max scaling,standardization 等。

2. BatchNorm

对输入进行标准化的时候,计算每个特征在样本集合中的均值、方差;然后将每个样本的每个特征减去该特征的均值,并除以它的方差。用数学公式表示,即:

而所谓的BatchNorm, 就是神经网络中间,在一小撮batch样本中进行标准化。具体如下(B: batch size)

注意,BatchNorm作为神经网络的一层,是有两个参数\left ( \gamma, \beta \right )要训练的,分别称为拉伸和偏移参数。可能你会有疑问,既然已经对 u_{b} 作了标准化得到了\hat{u_{b}}  ,为什么还要用 \gamma, \beta将它“还原”呢?

实际上,设置这两个参数是为了给神经网络足够的自由度。如果经过训练\gamma \approx \hat{\sigma} _{batch},\beta \approx \hat{\mu} _{batch}, 说明神经网络认为,不需要进行批标准化即可使loss function最小化,我们也充分“尊重”它的选择。

3. BN 的特点

  • 使用 BatchNorm,我们可以尝试更大的学习率,从而加速收敛,但一般不会改变模型的精度;
  • BatchNorm 的效果依赖于 Batch size;一般需要较大的 Batch size(>16)才能有好的效果
  • 和 Dropout 一样,BatchNorm 在训练和推理时有不同的行为:训练时,它基于每个 batch 计算均值和方差,因此 batch size 必须足够大才能较好反映统计性质;推理时,BatchNorm 则直接用训练集整体的均值和方差进行标准化

训练集整体的均值和方差如何得到?——在每个batch的均值和方差计算中,通过移动平均估算得到。

4. BN 的位置

BatchNorm 究竟应该放在哪,现在还存在争议。很多人说应该放在激活函数之前,但也有声音说应该放在激活函数之后。思考一下,两种说法都有道理。举个简单的例子。

前一种说法是要对 \omega x+b作BatchNorm,这样可以保证 \omega x+b 在0附近, {\sigma}'(\omega x+b) 不至于太小;后一种说法 BatchNorm 的作用对象则直接是x ,这样可以控制梯度 \frac{\partial y}{\partial w} 在合理的范围内,不会因为 x 的极端取值而波动过大。

但现在看来,前一种声音是占上风的:将 BatchNorm 作用在全连接层和卷积层的输出上,激活函数之前。在全连接网络中,顺序是:线性组合+BatchNorm+Activation

对于全连接层,BatchNorm 作用在特征维上。假设输入矩阵大小是 m×n —— m 等于 batch size,即这个小批量中的样本数, n 表示特征数。我们要在每个特征上计算 m 个样本的均值和方差,也就是对每一列做计算。

在卷积神经网络中,顺序是:卷积层+BatchNorm+Activation+池化+全连接。要注意一点是,如果卷积层有K个卷积核(即K个通道),要对每个通道的输出分别做批标准化,且每个通道都拥有独立的拉伸和偏移参数。

对于卷积层,BatchNorm 作用在通道维上。我们先考虑一个 1×1 的卷积层,通道数为 k 。它其实就等价于神经元个数为 k 的全连接层。图片中每个像素点都由一个 k 维的向量表示,可以看作是像素点的 k 个特征。同一批量各个图片的各个像素点,就是不同的样本,共有 m×p×q 个样本, m,p,q 分别为 batch size、高、宽。

类比全连接层 BatchNorm 作用在特征维上,要在每个通道(即每个特征)上计算 m×p×q 个样本的均值和方差

设小批量中有m个样本。在单个通道上,假设卷积计算输出的高和宽分别为p和q。我们需要对该通道中m×p×q个元素同时标准化:对这些元素做标准化计算时,我们使用相同的均值和方差,即该通道中m×p×q个元素的均值和方差。——卷积神经网络之Batch Normalization(一)

5. BN的理解与延伸

BN 效果好是因为 BN 的存在会引入 mini-batch 内其他样本的信息,就会导致预测一个独立样本时,其他样本信息相当于正则项,使得 loss 曲面变得更加平滑,更容易找到最优解。相当于一次独立样本预测可以看多个样本,学到的特征泛化性更强,更加 general

Conv+BN+Relu 是卷积网络的一个常见组合。在模型推理时,BN 层的参数已经固定下来,本质就是一个线性变换。我们可以把 Conv+BN+Relu 进行算子融合,以加速模型推理

除了BN层,还有GN(Group Normalization)、LN(Layer Normalization、IN(Instance Normalization)这些个标准化方法,每个标注化方法都适用于不同的任务。

这个图很好地说明了BatchNorm、LayerNorm、InstanceNorm、GroupNorm的区别。N代表batch size;C代表卷积核个数(通道个数);H,W代表卷积结果的高和宽。

BatchNorm: 计算均值和方差时,考虑N * H * W 个元素;对每个通道分别做标准化

LayerNorm:计算均值和方差时,考虑C * H * W 个元素;对batch中的每个instance分别做标准化

InstanceNorm:计算均值和方差时,考虑H * W 个元素;对每个通道、batch中的每个instance分别做标准化

GroupNorm:介于LayerNorm和InstanceNorm二者之间,将C个通道分组,然后进行标准化。

直觉上来讲,GroupNorm把提取到类似特征的不同卷积核分到同一个group中。对这些卷积核进行标准化,确实make sense. 而且GroupNorm摆脱了对batch size的依赖。

GN在训练集上表现最好,在测试集上稍逊于BN(引自 Group Normalization (Yuxin & Kaiming, 2018))

6. BN vs LN

Transformer模型中用到了LayerNorm,着重对比一下LayerNorm和BatchNorm。

对于一个输入序列 (x1,x2,...,xn) ,每一个 xi 都是 d 维的向量。譬如输入序列是一个句子,每个单词 xi 都用一个 d 维的向量表示。

X轴是序列长度(n),Y轴是特征个数(d),Z轴是Batch size

此时BatchNorm是对图中蓝色框作标准化处理,就像我们上面说的——对每个特征分别做标准化;而LayerNorm针对每一个输入序列,对图中黄色框作标准化处理。总结来说,BatchNorm盯住每一个特征;而LayerNorm盯住的是每一个样本。

那么为什么Transformer模型要用LayerNorm而不是BatchNorm呢?

实际上,序列模型的背景下,BatchNorm有一个天然的硬伤,这使得它在所有序列模型中都不吃香:输入序列的长度(n)可能不一致。一般来说,我们会规定一个最长的序列长度,长度不够的序列用0填充。譬如下图这样,Batch中的序列长短不一。

如果用BatchNorm,以一个feature为例,它的标准化有效范围是蓝色的图,其余用0填充;如果是LayerNorm,对于4个序列,它们的标准化有效范围是黄色的图。

直觉上来说,对于BatchNorm的计算方法,当Batch中序列长度差距过大时,均值和方差的波动也会很大

但这个问题对于LayerNorm来说并不存在,因为它是在每一个序列内部计算均值和方差的。

这样,我们可以直观地理解,为什么BatchNorm对于序列模型并不好用;为什么Transformer要采用LayerNorm

7. BN代码实现

我们翻一翻常见的backbone的结构。可以看到在官方Pytorch的resnet.pyclass BasicBlock中,forward时的基本结构是Conv+BN+Relu:

# 省略了一些地方
class BasicBlock(nn.Module):def __init__(self,...) -> None:...self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x: Tensor) -> Tensor:identity = x# 常见的Conv+BN+Reluout = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 又是Conv+BN+reluout = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

resnet作为我们常见的万年青backbone不是没有理由的,效果好速度快方便部署。当然还有很多其他优秀的backbone,这些backbone的内部结构也多为Conv+BN+Relu或者Conv+BN的结构。

参考资料:BatchNorm and its variants - 知乎normalization 和 standardization 到底什么区别?_为什么batch normalization使用standardization而不是normaliz-CSDN博客不论是训练还是部署都会让你踩坑的Batch Normalization - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/154858.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL 入门

文章目录 PostgreSQL介绍PostgreSQL和MySQL的区别PostgreSQL的安装PostgreSQL的配置远程连接配置配置数据库的日志 PostgreSQL基本操作用户操作权限操作 图形化界面安装总结 PostgreSQL介绍 PostgreSQL是一个功能强大的 开源 的关系型数据库,底层基于C实现。其开源…

模板初阶学习

✨前言✨ 📘 博客主页:to Keep博客主页 🙆欢迎关注,👍点赞,📝留言评论 ⏳首发时间:2023年11月21日 📨 博主码云地址:博主码云地址 📕参考书籍&…

【刷题专栏—突破思维】LeetCode 138. 随机链表的复制

前言 随机链表的复制涉及到复制一个链表,该链表不仅包含普通的next指针,还包含random指针,该指针指向链表中的任意节点或空节点。 文章目录 原地修改链表 题目链接: LeetCode 138. 随机链表的复制 原地修改链表 题目介绍&#xf…

拖拽场景遇到 iframe 无法拖拽的问题解决方案

描述一个场景:在网页中,分为上下两部分布局,下半部分显示操作日志,下半部分的区域高度是可拖拽调整的,但是如果下半部分嵌入一个 iframe 的时候,往上拖拽可以,但是往下拖拽,一旦到了…

分类问题的评价指标

一、logistic regression logistic regression也叫做对数几率回归。虽然名字是回归,但是不同于linear regression,logistic regression是一种分类学习方法。 同时在深度神经网络中,有一种线性层的输出也叫做logistic,他是被输入…

以太网_寻址

【架构图】 【ipconfig/all】 MAC地址:作用于本地网络,数据包发送到本地交换机或路由器后经判断目的地址是本地网络地址会转发给当前MAC地址对应的网线端口。 IP地址:供路由器寻址,会跟子网掩码进行运算,属于同一网络…

git问题: git@10.18.*.*: Permission denied (publickey,password)

遇到的问题: openSSH版本太高,openssh高版本默认禁止ssh-rsa加密算法,直接换ed25519 执行以下命令: 在.ssh目录下执行:ssh-keygen -t ed25519 -C “youremail.com” ssh-add ~/.ssh/id_ed25519 将id_ed25519.pub添加…

BigDecimal的常见陷阱

文章目录 BigDecimal概述BigDecimal常见陷阱1.使用BigDecimal的构造函数传入浮点数2.使用equals()方法进行数值比较3.使用不正确的舍入模式 总结: BigDecimal概述 BigDecimal 是 Java 中的一个类,用于精确表示和操作任意精度的十进制数。它提供了高精度的…

UE 材质,如何只取0~1之间的值,其余值抛弃

假如0~1,floor为0,abs为0,Saturate为0,1-x为1,很好 假如1~2,floor为1,abs为1,Saturate为1,1-x为0,很好 假如2~3,floor为2,abs为2&am…

软件测试/人工智能丨引领未来:软件测试中的人工智能

在数字化潮流的推动下,软件测试领域正在经历一场革命性的变革,而这场变革的关键推手正是人工智能(AI)。AI的引入不仅加速了测试过程,而且赋予了测试领域新的可能性,将我们带入了一个前所未有的未来。 智能…

【MySQL--->用户管理】

文章目录 [TOC](文章目录) 一、用户管理表二、基本操作三、用户权限分配给用户某个数据库中某个表的某个权限. grant 权限 on 库.表名 to 用户名主机名. ![在这里插入图片描述](https://img-blog.csdnimg.cn/fe8eb171ef9343c3a09bd64d4f0db5c1.png)分配给用户某个数据库中全部表…

13.Oracle通过JDBC连接Java

Oracle通过JDBC连接Java 一、什么是JDBC二、Oracle通过JDBC连接Java1、导入jar包1.1 下载jar包1.2 将jar包导入到java项目中1.3编译jar包 2、连接数据库2.1 编写jdbc工具类2.2 对数据进行基本操作 一、什么是JDBC JDBC(Java Database Connectivity)是Jav…

微波功率计/频率计-87234系列USB峰值/平均功率计

仪器仪表 苏州新利通 87234系列 USB峰值/平均功率计 频率范围覆盖:50MHz~67GHz 一款基于USB 2.0接口的二极管检波式宽带功率测量仪器 国产思仪功率计 01 产品综述 87234D/E/F/L USB峰值/平均功率计是一款基于USB 2.0接口的二极管检波式宽带功率测…

PowerShell无人参与安装最新版本SQL Server Management Studio (SSMS)

文章目录 下载SQL Server Management Studio (SSMS)Power Shell实现无人安装推荐阅读 下载SQL Server Management Studio (SSMS) SSMS 19.2 是最新的正式发布 (GA) 版本。 如果已经安装了 SSMS 19 预览版,需要在安装 SSMS 19.2 之前将其卸载。 如果安装了 SSMS 19.…

负载均衡Ribbon和Feign的使用与区别

Ribbon 的介绍 Spring Cloud Ribbon 是基于Netflix Ribbon 实现的一套客户端负载均衡的工具。主要功能是提供客户端的软件负载均衡和服务调用。Ribbon 客户端组件提供一系列完善的配置项如连接超时,重试等。简单的说,就是在配置文件中列出Load Balancer…

JVM基础- 垃圾回收器

基本介绍 Java虚拟机(JVM)中的垃圾回收器是用来自动管理内存的关键组件。它负责识别并回收不再使用的内存,从而防止内存泄漏。不同的JVM实现提供了多种垃圾回收器,每种回收器都有其特定的使用场景和性能特点。以下是一些常见的JV…

【用unity实现100个游戏之16】Unity程序化生成随机2D地牢游戏2(附项目源码)

文章目录 先看看最终效果前言生成走廊生成房间修复死胡同增加走廊宽度获取走廊位置信息集合方法一方法二 源码完结 先看看最终效果 前言 上期已经实现了房间的生成,本期紧跟着上期内容,生成走廊并结合上期内容生成连通的房间。 生成走廊 修改Procedur…

集成多元算法,打造高效字面文本相似度计算与匹配搜索解决方案,助力文本匹配冷启动[BM25、词向量、SimHash、Tfidf、SequenceMatcher]

搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术细节以及项目实战(含码源) 专栏详细介绍:搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术…

zabbix的安装配置,邮件告警,钉钉告警

zabbix监控架构 zabbix优点 开源,无软件成本投入server对设备性能要求低支持设备多,自带多种监控模板支持分布式集中管理,有自动发现功能,可以实现自动化监控开放式接口,扩展性强,插件编写容易当监控的item…

力扣C++学习笔记——C++ 给vector去重

要使用std::set对std::vector进行去重操作,您可以将向量中的元素插入到集合中,因为std::set会自动去除重复元素。然后,您可以将集合中的元素重新存回向量中。以下是一个示例代码,演示如何使用std::set对std::vector进行去重&#…