BatchNormalization和Layer Normalization解析

Batch Normalization

是google团队2015年提出的,能够加速网络的收敛并提升准确率

1.Batch Normalization原理

图像预处理过程中通常会对图像进行标准化处理,能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应的feature map的数据要满足分布规律)。而我们BN的目的就是使feature map满足均值为0,方差为1的分布规律。

对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,其中x^1就代表我们的R通道所对应的特征矩阵,依次类推。标准化处理也就是分别对R通道,G通道,B通道进行处理。

让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后再进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的BN,也就是计算一个Batch数据的feature map然后进行标准化(batch越大越接近整个数据集的分布,效果越好)。

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^1代表batch的所有的feature的channel1的数据。然后分别计算x^1和x^2的均值和方差。然后再根据标准差计算公式分别计算每个channel 的值(\varepsilon是很小的常量,放置分母为0的情况)。在训练过程中要去不断地计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后再我们的验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理。

\gamma是用来调整数值分布的方差大小,默认为1,\beta是用来调节数值均值的位置,默认值为0。这两个参数实在反向传播过程中学习到的。

2.使用Pytorch进行实验

在训练过程中,均值和方差是同通过计算当前批次数据得到的记录为\mu _{now},\delta_{now} ^{2},而我们的验证以及预测过程中使用的均值方差是一个统计量为\mu _{statistic},\delta _{statistic}^{2}。具体更新策略如下,其中momentum默认取0.1:

\mu _{statistic+1} = 0.9*\mu _{statistic}+0.1*\mu _{now}\\ \delta _{statistic+1}^{2} = 0.9*\delta _{statistic}^{2}+0.1*\delta _{now}^{2}

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\beta=0,\gamma=1。

import numpy as np
import torch.nn as nn
import torchdef bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):# [batch,channel, height, weight]feature_t = feature[:, i, :, :]mean_t = feature_t.mean()#总体标准差std_t1 = feature_t.std()#样本标准差std_t2 = feature_t.std(ddof = 1)#bn process#这里记得加上eps和pytorch保持一致feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2+ 1e-5)#更新计算均值mean[i]  = mean[i]*0.9 + mean_t * 0.1var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1print(feature)#随机生成一个batch为2,channel为2,height=width=2的特征向量
#[batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
#初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
#print(feature1.numpy())#注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)bn = nn.BatchNorm2d(2, eps =  1e-5)
output = bn(feature1)
print(output)

 

3.使用BN时需要注意的问题

(1)训练时要将training采纳数设置为True,在验证时将training参数设置为False。在Pytorch中了可以通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层和激活层之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,及时使用了偏置bias求出的结果也是一样的。

 


Layer Normalization

Layer Normalization针对自然语言处理提出的,为什么不用BN呢,因为在RNN这类时序网络中,时序的长度并不是一个定值(网络深度不一定相同),比如每句话的长短都不一定相同,所以很难去使用BN,所以作者提出了Layer Normalization(图像处理领域BN比LN更有效),但现在很多人将自然语言领域的模型用来处理图像,比如Vision Transformer,此时会涉及到LN。

直接看Pytorch 官方给出的关于LayerNorm 的介绍。不同的是,BN是对一个batch数据的每个channel进行Norm处理,一个for循环,但LN是对单个数据的制定维度进行Norm处理与batch无关而且BN中训练时是需要累计moving_mean和moving_var两个变量的(所以BN中有4个参数moving_mean,moving_var,\beta ,\gamma),但LN不需要累计只有\beta ,\gamma两个参数。

在Pytorch的LayerNorm类中有个normalized_shape参数,可以指定要Norm的维度(注意,函数说明中the last certain number of dimensions,指定的维度必须是从最后一维开始)。比如我们的数据shape是[4,2,3],那么normalized_shape可以是[3](最后一维进行Norm处理),也可以是[2,3](Norm最后两个维度),也可以是整个维度[4,2,3],但不能是[2]或者[4,2],否则会报错。

y = \frac{x-E[X]}{\sqrt{Var[x]+\varepsilon}}*\gamma +\beta

import torch
import torch.nn as nndef layer_norm_process(feature:torch.Tensor, beta=0.,gamma = 1.,eps=1e-5):var_mean = torch.var_mean(feature, dim = -1, unbiased = False)#均值mean = var_mean[1]#方差var = var_mean[0]#layer norm processfeature  = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)feature = feature*gamma+betareturn featuredef main():t = torch.randn(4, 2, 3)print(t)#仅在最后一个维度上做norm处理norm = nn.LayerNorm(normalized_shape= t.shape[-1], eps = 1e-5)#官方layer norm处理t1 = norm(t)#自己实现的layer norm处理t2 = layer_norm_process(t, eps = 1e-5)print("t1:\n",t1)print("t2:\n",t2)if __name__ == '__main__':main()
tensor([[[ 0.8512,  0.4201, -0.3457],[ 0.4701, -0.0647,  0.0733]],[[-0.9950, -0.4634,  0.0540],[ 0.4096,  0.4037, -0.0914]],[[-2.3165,  1.3059,  0.3183],[-0.9716,  0.4956,  0.4524]],[[-0.6209, -0.5958,  0.3212],[-0.8762,  0.3176, -0.5427]]])
t1:tensor([[[ 1.0963,  0.2254, -1.3218],[ 1.3697, -0.9893, -0.3804]],[[-1.2302,  0.0110,  1.2192],[ 0.7198,  0.6942, -1.4140]],[[-1.3642,  1.0050,  0.3591],[-1.4137,  0.7385,  0.6752]],[[-0.7355, -0.6783,  1.4138],[-1.0123,  1.3614, -0.3490]]], grad_fn=<NativeLayerNormBackward0>)
t2:tensor([[[ 1.0963,  0.2254, -1.3218],[ 1.3697, -0.9893, -0.3804]],[[-1.2302,  0.0110,  1.2192],[ 0.7198,  0.6942, -1.4140]],[[-1.3642,  1.0050,  0.3591],[-1.4137,  0.7385,  0.6752]],[[-0.7355, -0.6783,  1.4138],[-1.0123,  1.3614, -0.3490]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28948.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RN全局封装一个toast弹窗提示(在任意地方只要调用方法就能使用)

创建Toast.js let toastInstance = null;export const showToast = options => {if (toastInstance) {// toastInstance有值的话toastInstance(options);} else {console.error(Toast instance is not initialized);} };export const setToastInstance = instance => {t…

SQL语言-关系数据库的标准语言

一、SQL是关系数据库的标准语言&#xff0c;兼有关系代数和关系演算的特点。具有数据定义、数据查询、数据更新、数据控制等功能。 二、SQL的数据定义主要包括对基本表TABLE、视图VIEW、索引INDEX的创建CREATE和撤销DROP。通过定义表的主键、外键&#xff0c;以及有关的约束&a…

从hip cuda kernel 的汇编语言来感受 AMD GPU内部工作方式

0, 无参数 kernel 汇编语言示例 ./param_00.hip __global__ void WWWWW() {((int*)0x8888888)[3] 0x77777; } ../../../local_amdgpu/bin/clang -O1 --cuda-device-only --offload-archgfx906 -S ./param_00.hip -o param_00.s _Z5WWWWWv: ; …

一招解决家里粉尘螨虫太多难题?家用空气净化器哪款品牌效果好?

一到夏天&#xff0c;两天不打扫家里&#xff0c;家里就会布满一层粉尘。而且春夏的气候也是粉尘螨虫生长和繁殖疯狂时期&#xff0c;一不注意室内空气污染卫生的情况下&#xff0c;就会加剧尘螨的滋生&#xff0c;体质弱、敏感的人群生活在这样的空气环境下&#xff0c;还会增…

敏捷开发时代,彻底结束了

最近&#xff0c;我收到一位读者的私信&#xff0c;他最近“内耗”得非常厉害&#xff0c;他可能一时兴起把我的私信当作了吐槽箱。 他们公司一直实行敏捷的管理模式&#xff0c;复盘发现了一个问题&#xff1a;发布与迭代具有强相关性&#xff0c;一个迭代就发布一次&#xf…

Hadoop安装和测试

一&#xff0c;下载 地址&#xff1a;Index of /dist/hadoop/common 选择3.3.6版本&#xff08;最新版本之前的一个版本&#xff0c;一般比较稳定&#xff09; 二&#xff0c;解压 解压到/data/module目录&#xff0c;这里随便自定义就好。 tar -zxvf hadoop-3.3.6.tar.gz …

从《2024年人工智能指数报告》 看AI的最新发展趋势

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 《2024年人工智能指数报告》是由斯坦福大学的“以人为本”人工智能研究所&#xff08;Stanford HAI&#xff09;发布的&#xff0c;具体发布时间…

百货商场:打造品质生活

走进我们的百货商场&#xff0c;仿佛置身于一个五彩斑斓的梦幻世界。百货&#xff0c;不仅仅是购物的场所&#xff0c;更是一种品质生活的体验。 在这里&#xff0c;您可以找到最适合自己的商品选择。从家居用品到时尚服饰&#xff0c;从美食佳肴到美妆护肤&#xff0c;每一样商…

深入探索Java开发世界:Java基础~类型分析大揭秘

文章目录 一、基本数据类型二、封装类型三、类型转换四、集合类型五、并发类型 Java基础知识&#xff0c;类型知识点梳理~ 一、基本数据类型 Java的基本数据类型是语言的基础&#xff0c;它们直接存储在栈内存中&#xff0c;具有固定的大小和不变的行为。 八种基本数据类型的具…

Vue46-render函数

一、非单文件和单文件的main.js对比 1-1、非单文件的main.js 1-2、 单文件的main.js 将单文件的main.js中的render函数变成非单文件的main.js中的template形式&#xff0c;报如下错误&#xff1a; 解决方式&#xff1a; 二、解决方式 2-1、引入完成版的vue.js 精简版的vue&a…

TEA 加密的 Java 实现

import java.nio.ByteBuffer; import java.nio.ByteOrder;public class TeaUtils {private static final int DELTA 0x9E3779B9;private static final int ROUND 32;private static final String KEY "password";/*** 加密字符串&#xff0c;使用 TEA 加密算法*/p…

推广结算统计,Xinstall助您轻松掌握每一分投入与回报!

在移动互联网时代&#xff0c;App的推广与运营离不开精准的数据支持和高效的结算系统。然而&#xff0c;面对众多的推广渠道和复杂的结算流程&#xff0c;如何确保每一分投入都能得到合理的回报&#xff0c;成为了众多企业和开发者关注的焦点。今天&#xff0c;我们就来聊聊如何…

半监督学习

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 介绍一、Self Training自训练1、介绍2、代码示例3、参数解释 二、Label Propagation&#xff08;标签传播&#xff09;1、介绍2、代码示例3、参数解释 三、Label Spread…

深入剖析Java线程池之“newWorkStealingPool“

1. 概述 newWorkStealingPool 是Java 8中引入的一个新型线程池,它基于ForkJoinPool实现,并采用了“工作窃取”(Work-Stealing)算法。这种线程池特别适用于可并行化且计算密集型的任务,能够充分利用多核CPU资源,提高任务执行效率。 2. 工作窃取算法(Work-Stealing Algor…

618狂欢日,美味产品齐上阵,超值优惠等你享

这个充满激情与活力的6月&#xff0c;我们带着满满的诚意与惊喜&#xff0c;为广大美食爱好者们开启一场独特的618狂欢之旅。 当我们提及甘肃&#xff0c;那丰富多样的甘肃传统美食便是不得不说的瑰宝。烤馍&#xff0c;油饼&#xff0c;锅盔、擀面皮、浆水等每一种美食都…

java算法:插入排序

这里写目录标题 基本使用优缺点尝试优化二分查找插入减少交换操作 基本使用 插入排序是一种简单直观的排序算法&#xff0c;它的工作原理是将待排序的数组分为已排序和未排序两部分&#xff0c;逐步将未排序部分的元素插入到已排序部分中的正确位置&#xff0c;直到整个数组有…

ffmpeg的部署踩坑及简单使用方式

ffmpeg的使用方式有以下几种: 使用原生安装包 直接在ffmpeg官网上下载安装该软件,加入到环境变量中就可以使用了 优点:简单,灵活,代码中也不用添加其他第三方的包 缺点:需要手动安装ffmpeg,这点比较麻烦 部署-windows 在windows环境下,有时就算加入到了环境变量,…

你知道花洒其实起源于中国古代吗?

花洒作为日常生活中不可或缺的一部分&#xff0c;其发展历程不仅见证了人类文明的进步&#xff0c;也反映了生活美学的演变。从最初的简单构想到现代的智能化设计&#xff0c;花洒的变迁历程是一部生动的人类生活史。 早在隋朝时期&#xff0c;我们的祖先就已经有了花洒的初步构…

【Go语言】Go语言中的接口类型

Go语言中的接口类型 接口&#xff08;interface&#xff09;定义了一个对象的行为规范&#xff0c;只定义规范不实现&#xff0c;由具体的对象来实现规范的细节。 1.接口类型 1.1 接口类型的说明 Go语言中 接口&#xff08;interface&#xff09; 是一种抽象的类型。 接口…

《纪元 1800》好玩吗? 苹果电脑能玩《纪元 1800》吗?

《纪元1800》是一款不错的策略游戏&#xff0c;这款游戏因为画面和玩法独特深受玩家们的喜爱。下面我们来看看《纪元 1800》好玩吗&#xff0c;苹果电脑能玩《纪元 1800》吗的相关内容。 一、《纪元1800》好玩吗 《纪元1800》是一款备受瞩目的策略游戏。下面让我们来看看这款…