神经网络中BN层简介及位置分析

1. 简介

Batch Normalization是深度学习中常用的技巧,Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe and Szegedy, 2015) 第一次介绍了这个方法。

这个方法的命名,明明是Standardization, 非要叫Normalization, 把本来就混用、意义不明的两个词更加搅得一团糟。那standardization 和 Normalization有什么区别呢?

一般是下面这样(X是输入数据集):

  • normalization(也叫 min-max scaling),一般译做 “归一化”:

  • standardization,一般译做 “标准化”:

Batch-Norm 是一个网络层,对中间结果作上面说的 standardization 操作。实际上 standardization 也可以叫做 Z-score normalization。所以可以这样理解,standardization 是一种特殊的 normalization。normalization 作为一个 scaling 的大类,包括 min-max scaling,standardization 等。

2. BatchNorm

对输入进行标准化的时候,计算每个特征在样本集合中的均值、方差;然后将每个样本的每个特征减去该特征的均值,并除以它的方差。用数学公式表示,即:

而所谓的BatchNorm, 就是神经网络中间,在一小撮batch样本中进行标准化。具体如下(B: batch size)

注意,BatchNorm作为神经网络的一层,是有两个参数\left ( \gamma, \beta \right )要训练的,分别称为拉伸和偏移参数。可能你会有疑问,既然已经对 u_{b} 作了标准化得到了\hat{u_{b}}  ,为什么还要用 \gamma, \beta将它“还原”呢?

实际上,设置这两个参数是为了给神经网络足够的自由度。如果经过训练\gamma \approx \hat{\sigma} _{batch},\beta \approx \hat{\mu} _{batch}, 说明神经网络认为,不需要进行批标准化即可使loss function最小化,我们也充分“尊重”它的选择。

3. BN 的特点

  • 使用 BatchNorm,我们可以尝试更大的学习率,从而加速收敛,但一般不会改变模型的精度;
  • BatchNorm 的效果依赖于 Batch size;一般需要较大的 Batch size(>16)才能有好的效果
  • 和 Dropout 一样,BatchNorm 在训练和推理时有不同的行为:训练时,它基于每个 batch 计算均值和方差,因此 batch size 必须足够大才能较好反映统计性质;推理时,BatchNorm 则直接用训练集整体的均值和方差进行标准化

训练集整体的均值和方差如何得到?——在每个batch的均值和方差计算中,通过移动平均估算得到。

4. BN 的位置

BatchNorm 究竟应该放在哪,现在还存在争议。很多人说应该放在激活函数之前,但也有声音说应该放在激活函数之后。思考一下,两种说法都有道理。举个简单的例子。

前一种说法是要对 \omega x+b作BatchNorm,这样可以保证 \omega x+b 在0附近, {\sigma}'(\omega x+b) 不至于太小;后一种说法 BatchNorm 的作用对象则直接是x ,这样可以控制梯度 \frac{\partial y}{\partial w} 在合理的范围内,不会因为 x 的极端取值而波动过大。

但现在看来,前一种声音是占上风的:将 BatchNorm 作用在全连接层和卷积层的输出上,激活函数之前。在全连接网络中,顺序是:线性组合+BatchNorm+Activation

对于全连接层,BatchNorm 作用在特征维上。假设输入矩阵大小是 m×n —— m 等于 batch size,即这个小批量中的样本数, n 表示特征数。我们要在每个特征上计算 m 个样本的均值和方差,也就是对每一列做计算。

在卷积神经网络中,顺序是:卷积层+BatchNorm+Activation+池化+全连接。要注意一点是,如果卷积层有K个卷积核(即K个通道),要对每个通道的输出分别做批标准化,且每个通道都拥有独立的拉伸和偏移参数。

对于卷积层,BatchNorm 作用在通道维上。我们先考虑一个 1×1 的卷积层,通道数为 k 。它其实就等价于神经元个数为 k 的全连接层。图片中每个像素点都由一个 k 维的向量表示,可以看作是像素点的 k 个特征。同一批量各个图片的各个像素点,就是不同的样本,共有 m×p×q 个样本, m,p,q 分别为 batch size、高、宽。

类比全连接层 BatchNorm 作用在特征维上,要在每个通道(即每个特征)上计算 m×p×q 个样本的均值和方差

设小批量中有m个样本。在单个通道上,假设卷积计算输出的高和宽分别为p和q。我们需要对该通道中m×p×q个元素同时标准化:对这些元素做标准化计算时,我们使用相同的均值和方差,即该通道中m×p×q个元素的均值和方差。——卷积神经网络之Batch Normalization(一)

5. BN的理解与延伸

BN 效果好是因为 BN 的存在会引入 mini-batch 内其他样本的信息,就会导致预测一个独立样本时,其他样本信息相当于正则项,使得 loss 曲面变得更加平滑,更容易找到最优解。相当于一次独立样本预测可以看多个样本,学到的特征泛化性更强,更加 general

Conv+BN+Relu 是卷积网络的一个常见组合。在模型推理时,BN 层的参数已经固定下来,本质就是一个线性变换。我们可以把 Conv+BN+Relu 进行算子融合,以加速模型推理

除了BN层,还有GN(Group Normalization)、LN(Layer Normalization、IN(Instance Normalization)这些个标准化方法,每个标注化方法都适用于不同的任务。

这个图很好地说明了BatchNorm、LayerNorm、InstanceNorm、GroupNorm的区别。N代表batch size;C代表卷积核个数(通道个数);H,W代表卷积结果的高和宽。

BatchNorm: 计算均值和方差时,考虑N * H * W 个元素;对每个通道分别做标准化

LayerNorm:计算均值和方差时,考虑C * H * W 个元素;对batch中的每个instance分别做标准化

InstanceNorm:计算均值和方差时,考虑H * W 个元素;对每个通道、batch中的每个instance分别做标准化

GroupNorm:介于LayerNorm和InstanceNorm二者之间,将C个通道分组,然后进行标准化。

直觉上来讲,GroupNorm把提取到类似特征的不同卷积核分到同一个group中。对这些卷积核进行标准化,确实make sense. 而且GroupNorm摆脱了对batch size的依赖。

GN在训练集上表现最好,在测试集上稍逊于BN(引自 Group Normalization (Yuxin & Kaiming, 2018))

6. BN vs LN

Transformer模型中用到了LayerNorm,着重对比一下LayerNorm和BatchNorm。

对于一个输入序列 (x1,x2,...,xn) ,每一个 xi 都是 d 维的向量。譬如输入序列是一个句子,每个单词 xi 都用一个 d 维的向量表示。

X轴是序列长度(n),Y轴是特征个数(d),Z轴是Batch size

此时BatchNorm是对图中蓝色框作标准化处理,就像我们上面说的——对每个特征分别做标准化;而LayerNorm针对每一个输入序列,对图中黄色框作标准化处理。总结来说,BatchNorm盯住每一个特征;而LayerNorm盯住的是每一个样本。

那么为什么Transformer模型要用LayerNorm而不是BatchNorm呢?

实际上,序列模型的背景下,BatchNorm有一个天然的硬伤,这使得它在所有序列模型中都不吃香:输入序列的长度(n)可能不一致。一般来说,我们会规定一个最长的序列长度,长度不够的序列用0填充。譬如下图这样,Batch中的序列长短不一。

如果用BatchNorm,以一个feature为例,它的标准化有效范围是蓝色的图,其余用0填充;如果是LayerNorm,对于4个序列,它们的标准化有效范围是黄色的图。

直觉上来说,对于BatchNorm的计算方法,当Batch中序列长度差距过大时,均值和方差的波动也会很大

但这个问题对于LayerNorm来说并不存在,因为它是在每一个序列内部计算均值和方差的。

这样,我们可以直观地理解,为什么BatchNorm对于序列模型并不好用;为什么Transformer要采用LayerNorm

7. BN代码实现

我们翻一翻常见的backbone的结构。可以看到在官方Pytorch的resnet.pyclass BasicBlock中,forward时的基本结构是Conv+BN+Relu:

# 省略了一些地方
class BasicBlock(nn.Module):def __init__(self,...) -> None:...self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x: Tensor) -> Tensor:identity = x# 常见的Conv+BN+Reluout = self.conv1(x)out = self.bn1(out)out = self.relu(out)# 又是Conv+BN+reluout = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

resnet作为我们常见的万年青backbone不是没有理由的,效果好速度快方便部署。当然还有很多其他优秀的backbone,这些backbone的内部结构也多为Conv+BN+Relu或者Conv+BN的结构。

参考资料:BatchNorm and its variants - 知乎normalization 和 standardization 到底什么区别?_为什么batch normalization使用standardization而不是normaliz-CSDN博客不论是训练还是部署都会让你踩坑的Batch Normalization - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/154858.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL 入门

文章目录 PostgreSQL介绍PostgreSQL和MySQL的区别PostgreSQL的安装PostgreSQL的配置远程连接配置配置数据库的日志 PostgreSQL基本操作用户操作权限操作 图形化界面安装总结 PostgreSQL介绍 PostgreSQL是一个功能强大的 开源 的关系型数据库,底层基于C实现。其开源…

面向对象程序设计1-类的定义和使用

第1关:数字时钟走字 任务描述 本关任务:本题中已给出一个时钟类的定义,请模拟数字时钟走字过程。 相关知识 为了完成本关任务,你需要掌握:1.类和对象。 编程要求 根据提示,在右侧编辑器补充代码&…

模板初阶学习

✨前言✨ 📘 博客主页:to Keep博客主页 🙆欢迎关注,👍点赞,📝留言评论 ⏳首发时间:2023年11月21日 📨 博主码云地址:博主码云地址 📕参考书籍&…

Scala---WordCount

一、创建Maven项目导入pom.xml文件 安装Maven仓库管理工具,版本要求是3.2版本以上。新建Maven项目,配置pom.xml。导入必要的包。 二、Spark-Scala版本的WordCount 1.val conf new SparkConf() 2.conf.setMaster("local") 3.conf.setAppNam…

4、FFmpeg命令行操作4

ffmpeg命令参数说明 主要参数: -i 设定输入流 -f 设定输出格式(format) -ss 开始时间 -t 时间长度 音频参数: -aframes 设置要输出的音频帧数 -b:a 音频码率 -ar 设定采样率 -ac 设定声音的Channel数 -acodec 设定声音编解码器,如果用copy表示原始编解码数…

【刷题专栏—突破思维】LeetCode 138. 随机链表的复制

前言 随机链表的复制涉及到复制一个链表,该链表不仅包含普通的next指针,还包含random指针,该指针指向链表中的任意节点或空节点。 文章目录 原地修改链表 题目链接: LeetCode 138. 随机链表的复制 原地修改链表 题目介绍&#xf…

拖拽场景遇到 iframe 无法拖拽的问题解决方案

描述一个场景:在网页中,分为上下两部分布局,下半部分显示操作日志,下半部分的区域高度是可拖拽调整的,但是如果下半部分嵌入一个 iframe 的时候,往上拖拽可以,但是往下拖拽,一旦到了…

分类问题的评价指标

一、logistic regression logistic regression也叫做对数几率回归。虽然名字是回归,但是不同于linear regression,logistic regression是一种分类学习方法。 同时在深度神经网络中,有一种线性层的输出也叫做logistic,他是被输入…

【python学习】基础篇-常用模块-shutil文件和目录操作

shutil模块是Python标准库中的一个模块,提供了对文件和目录进行高级操作的函数。 以下是shutil模块的一些常用函数: 1.复制文件: 将源文件src复制到目标文件dst。如果follow_symlinks为True,则会跟随符号链接。 shutil.copy(src, dst, *, f…

以太网_寻址

【架构图】 【ipconfig/all】 MAC地址:作用于本地网络,数据包发送到本地交换机或路由器后经判断目的地址是本地网络地址会转发给当前MAC地址对应的网线端口。 IP地址:供路由器寻址,会跟子网掩码进行运算,属于同一网络…

git问题: git@10.18.*.*: Permission denied (publickey,password)

遇到的问题: openSSH版本太高,openssh高版本默认禁止ssh-rsa加密算法,直接换ed25519 执行以下命令: 在.ssh目录下执行:ssh-keygen -t ed25519 -C “youremail.com” ssh-add ~/.ssh/id_ed25519 将id_ed25519.pub添加…

Java 数据结构、集合框架、ArrayList

一、Java数据结构: Java中的数据结构主要包含以下几种接口和类:枚举、位集合、向量、栈、字典、哈希表、属性。 枚举接口定义一种从数据结构中取回连续元素的方式。 位集合实现了一组可以单独设置和清除的位或标志。 向量类于传统数组相似&#xff0…

信也科技发布2023年Q3财报:数字金融服务业务增长稳健,持续拉动实体消费

11月21日,信也科技(NYSE:FINV)公布2023年第三季度未经审计的财务报告。财报显示,信也科技三季度在国内、国际市场延续稳健增长态势,实现季度营收31.98亿元(人民币,下同)&…

LeetCode 每日一题 2023/11/13-2023/11/19

记录了初步解题思路 以及本地实现代码;并不一定为最优 也希望大家能一起探讨 一起进步 目录 11/13 307. 区域和检索 - 数组可修改11/14 1334. 阈值距离内邻居最少的城市11/15 2656. K 个元素的最大和11/16 2760. 最长奇偶子数组11/17 2736. 最大和查询11/18 2342. 数…

【Python入门教程】OpenCV在图片/视频上添加Logo(水印)

还是老样子,最近项目需要在视频上添加logo,所以就找了一下相关资料,然后写了一段代码,今天给大家分享一下如何使用Python的OpenCV库给图片或视频添加水印和logo。 一、导入库 OpenCV库导入的时候是cv2 import cv2 二、代码部分 …

BigDecimal的常见陷阱

文章目录 BigDecimal概述BigDecimal常见陷阱1.使用BigDecimal的构造函数传入浮点数2.使用equals()方法进行数值比较3.使用不正确的舍入模式 总结: BigDecimal概述 BigDecimal 是 Java 中的一个类,用于精确表示和操作任意精度的十进制数。它提供了高精度的…

UE 材质,如何只取0~1之间的值,其余值抛弃

假如0~1,floor为0,abs为0,Saturate为0,1-x为1,很好 假如1~2,floor为1,abs为1,Saturate为1,1-x为0,很好 假如2~3,floor为2,abs为2&am…

软件测试/人工智能丨引领未来:软件测试中的人工智能

在数字化潮流的推动下,软件测试领域正在经历一场革命性的变革,而这场变革的关键推手正是人工智能(AI)。AI的引入不仅加速了测试过程,而且赋予了测试领域新的可能性,将我们带入了一个前所未有的未来。 智能…

【MySQL--->用户管理】

文章目录 [TOC](文章目录) 一、用户管理表二、基本操作三、用户权限分配给用户某个数据库中某个表的某个权限. grant 权限 on 库.表名 to 用户名主机名. ![在这里插入图片描述](https://img-blog.csdnimg.cn/fe8eb171ef9343c3a09bd64d4f0db5c1.png)分配给用户某个数据库中全部表…

Golang CSV Reader

导言 CSV(逗号分隔值)是一种常见的文件格式,用于存储和交换数据。它简单易用,具有广泛的应用场景,因此在处理和解析 CSV 文件时需要一个高效和可靠的方法。Golang 提供了一个强大的 CSV Reader 库,可以简化…