神经网络:卷积神经网络中的BatchNorm

在这里插入图片描述

一、BN介绍

1.原理

在机器学习中让输入的数据之间相关性越少越好,最好输入的每个样本都是均值为0方差为1。在输入神经网络之前可以对数据进行处理让数据消除共线性,但是这样的话输入层的激活层看到的是一个分布良好的数据,但是较深的激活层看到的的分布就没那么完美了,分布将变化的很严重。这样会使得训练神经网络变得更加困难。所以添加BatchNorm层,在训练的时候BN层使用batch来估计数据的均值和方差,然后用均值和方差来标准化这个batch的数据,并且随着不同的batch经过网络,均值和方差都在做累计平均。在测试的时候就直接作为标准化的依据。

这样的方法也有可能导致降低神经网络的表示能力,因为某些层的全局最优的特征可能不是均值为0或者方差为1的。所以BN层也是能够进行学习每个特征维度的缩放gamma和平移beta的来避免这样的情况。

2.BN层前向传播

def batchnorm_forward(x, gamma, beta, bn_param):"""先进行标准化再进行平移缩放running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: (N, D) 输入的数据- gamma: (D,) 每个特征维度数据的缩放- beta: (D,) 每个特征维度数据的偏移- bn_param: 字典,有如下键值:- mode: 'train'/'test' 必须指定- eps: 一个常量为了维持数值稳定,保证不会除0- momentum: 动量- running_mean: (D,) 积累的均值- running_var: (D,) 积累的方差Returns:- out: (N,D)- cache: 反向传播时需要的数据"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))out, cache = None, Noneif mode == 'train':sample_mean = np.mean(x, axis=0)sample_var = np.var(x, axis=0)# 先标准化x_hat = (x - sample_mean)/(np.sqrt(sample_var + eps))# 再做缩放偏移out = gamma * x_hat + betacache = (gamma, x, sample_mean, sample_var, eps, x_hat)running_mean = momentum * running_mean + (1-momuntum)*sample_meanrunning_var = momentum * running_var + (1-momentum)*sample_varelif mode == 'test':# 先标准化#x_hat = (x - running_mean)/(np.sqrt(running_var+eps))# 再做缩放偏移#out = gamma * x_hat + beta# 或者是下面的骚写法scale = gamma/(np.sqrt(running_var + eps))out = x*scale + (beta - running_mean*scale)else:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)bn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

3.BN层反向传播

def batchnorm_barckward(out, cache):"""反向传播的简单写法,易于理解Inputs:- dout: (N,D) dloss/dout- cache: (gamma, x, sample_mean, sample_var, eps, x_hat)Returns:- dx: (N,D)- dgamma: (D,) 每个维度的缩放和平移参数不同- dbeta: (D,)"""dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_1 = gamma * dout # dloss/dx_hat = dloss/dout * gamma (N, D)dx_2_b = np.sum((x - u_b) * dx_1, axis=0)dx_2_a = ((sigma_squared_b + eps)**-0.5)*dx_1dx_3_b = (-0.5) * ((sigma_squared_b + eps)**-1.5)*dx_2_bdx_4_b = dx_3_b * 1dx_5_b = np.ones_like(x)/N * dx_4_bdx_6_b = 2*(x-u_b)*dx_5_bdx_7_a = dx_6_b*1 + dx_2_a*1dx_7_b = dx_6_b*1 * dx_2_a*1dx_8_b = -1*np.sum(dx_7_b, axis=0)dx_9_b = np.ones_like(x)/N * dx_8_bdx_10 = dx_9_b + dx_7_adgamma = np.sum(x_hat * dout, axis=0)dbeta = np.sum(dout, axis=0)dx = dx_10return dx, dgamma, dbeta

下面是直接使用公式来计算:

def batchnorm_backward_alt(dout, cache):dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_hat = dout * gammadvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmeandgamma = np.sum(x_hat * dout, axis = 0)dbeta = np.sum(dout , axis = 0)return dx, dgamma, dbeta

4.BN有什么作用

  1. 对于不好的权重初始化有更高的鲁棒性,仍然能得到较好的效果。
  2. 能更好的避免过拟合。
  3. 解决梯度消失/爆炸问题,BN防止了前向传播的时候数值过大或者过小,这样就能让反向传播时梯度处于一个较好的区间内。

二、卷积神经网络中的BN

1.前向传播

def spatial_batchnorm_forward(x, gamma, beta, bn_param):"""利用普通神经网络的BN来实现卷积神经网络的BNInputs:- x: (N, C, H, W)- gamma: (C,)缩放系数- beta: (C,)平移系数- bn_param: 包含如下键的字典- mode: 'train'/'test'必须的键- eps: 数值稳定需要的一个较小的值- momentum: 一个常量,用来处理running mean和var的。如果momentum=0 那么之前不利用之前的均值和方差。momentum=1表示不利用现在的均值和方差,一般设置momentum=0.9- running_mean: (C,)- running_var: (C,)Returns:- out: (N, C, H, W)- cache: 反向传播需要的数据,这里直接使用了普通神经网络的cache"""N, C, H, W = x.shape# transpose之后(N, W, H, C) channel在这里就可以看成是特征temp_out, cache = batchnorm_forward(x.transpose(0, 3, 2, 1).reshape((N*H*W, C)), gamma, beta, bn_param)# 再恢复shapeout = temp_output.reshape(N, W, H, C).transpose(0, 3, 2, 1)return out, cache

2.反向传播

def spatial_batchnorm_backward(dout, cache):"""利用普通神经网络的BN反向传播实现卷积神经网络中的BN反向传播Inputs:- dout: (N, C, H, W) 反向传播回来的导数- cache: 前向传播时的中间数据Returns:- dx: (N, C, H, W)- dgamma: (C,) 缩放系数的导数- dbeta: (C,) 偏移系数的导数"""dx, dgamma, dbeta = None, None, NoneN, C, H, W = dout.shape# 利用普通神经网络的BN进行计算 (N*H*W, C)channel看成是特征维度dx_temp, dgamma, dbeta = batchnorm_backward_alt(dout.transpose(0, 3, 2, 1).reshape((N*H*W, C)), cache)# 将shape恢复dx = dx_temp.reshape(N, W, H, C).transpose(0, 3, 2, 1)return dx, dgamma, dbeta

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/682088.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

揭秘某电商公司最新面试流程

🏃‍♂️ 微信公众号: 朕在debugger© 版权: 本文由【朕在debugger】原创、需要转载请联系博主📕 如果文章对您有所帮助,欢迎关注、点赞、转发和订阅专栏! 记录近期某电商公司面试流程及问题,分为三面:…

Hive的相关概念——分区表、分桶表

目录 一、Hive分区表 1.1 分区表的概念 1.2 分区表的创建 1.3 分区表数据加载及查询 1.3.1 静态分区 1.3.2 动态分区 1.4 分区表的本质及使用 1.5 分区表的注意事项 1.6 多重分区表 二、Hive分桶表 2.1 分桶表的概念 2.2 分桶表的创建 2.3 分桶表的数据加载 2.4 …

【计算机网络】网际协议——互联网中的转发和编址

编址和转发是IP协议的重要组件 就像这个图所示,网络层有三个主要组件:IP协议,ICMP协议,路由选择协议IPV4 没有选项的时候是20字节 版本(号):4比特:规定了IP协议是4还是6首部长度&am…

作业2.14

指针练习 1、选择题 1.1、若有下面的变量定义,以下语句中合法的是(A)。 int i,a[10],*p; A) pa2; B) pa[5]; C) pa[2]2; D) p&(i2); 1.2、…

Servlet JSP-Eclipse安装配置Maven插件

Maven 是一款比较常用的 Java 开发拓展包,它相当于一个全自动 jar 包管理器,会导入用户开发时需要使用的相应 jar 包。使用 Maven 开发 Java 程序,可以极大提升开发者的开发效率。下面我就跟大家介绍一下如何在 Eclipse 里安装和配置 Maven 插…

医疗相关名词,医疗名词整理

1.系统类: HIS Hospital Information System,医院信息系统,在国际学术界已公认为新兴的医学信息学(Medical Informatics)的重要分支。美国该领域的著名教授Morris.Collen于1988年曾著文为医院信息系统下了如下定义:利用电子计算…

【安装指南】markdown神器之Typora下载、安装与无限使用详细教程

🌼一、概述 Typora是一款轻量级的Markdown编辑器,它提供了简洁的界面和直观的操作方式,专注于让用户更加专注于写作。Typora支持实时预览功能,用户在编辑Markdown文档时可以即时看到最终的样式效果,这有助于提高写作效…

Golang快速入门到实践学习笔记

Go学习笔记 1.基础 Go程序设计的一些规则 Go之所以会那么简洁,是因为它有一些默认的行为: 大写字母开头的变量是可导出的,也就是其它包可以读取 的,是公用变量;小写字母开头的就是不可导出的,是私有变量…

寒假学习记录11:grid布局

1. display:grid 2. grid-template-columns: 100px 100px 100px //指定每列的宽度 grid-template-rows: 100px 100px 100px //指定每行的宽度 3. column-gap: 24px //列间距 row-gap: 24px //行间距 gap: 24px //都设置 4.grid-template-areas用法 <!DO…

计算机组成原理 2 数据表示

机器数 研究机器内的数据表示&#xff0c;目的在于组织数据&#xff0c;方便计算机硬件直接使用。 需要考虑&#xff1a; 支持的数据类型&#xff1b; 能表示的数据精度&#xff1b; 是否有利于软件的移植 能表示的数据范围&#xff1b; 存储和处理的代价&#xff1b; ... 真值…

PHP开发日志 ━━ 深入理解三元操作与一般条件语句的不同

概况 三元运算符的功能与“if…else”流程语句一致。 在一般情况下&#xff0c;三元操作替换if条件语句可以精简代码&#xff0c;并且更为直观&#xff0c;但是在下面的情况中使用三元操作将会返回警告。 借图&#xff1a; 案例 比如原代码&#xff1a; class classA{publ…

DS:树及二叉树的相关概念

创作不易&#xff0c;兄弟们来波三连吧&#xff01;&#xff01; 一、树的概念及结构 1.1 树的概念 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树&#xff0c…

Java并发基础:ConcurrentLinkedDeque全面解析!

内容概要 ConcurrentLinkedDeque类提供了线程安全的双端队列操作&#xff0c;支持高效的并发访问&#xff0c;因此在多线程环境下&#xff0c;可以放心地在队列的两端添加或移除元素&#xff0c;而不用担心数据的一致性问题。同时&#xff0c;它的内部实现采用了无锁算法&…

概率论-随机变量

更多AI技术入门知识与工具使用请看下面链接&#xff1a; https://student-api.iyincaishijiao.com/t/iNSVmUE8/

二叉树-------前,中,后序遍历 + 前,中,后序查找+删除节点 (java详解)

目录 提要&#xff1a; 创建一个简单的二叉树&#xff1a; 二叉树的前中后序遍历&#xff1a; 二叉树的前序遍历&#xff1a; 二叉树的中序遍历&#xff1a; 二叉树的后续遍历&#xff1a; 小结&#xff1a; 二叉树的前中后续查找&#xff1a; 二叉树的前序查找&#…

MySQL表的增删查改(基础)

新增&#xff08;Create) 1.全列插入 全列单行插入 insert into 表名 values(值&#xff0c;值……)&#xff1b; 也可以全列且多行插入 insert into 表名 values (值&#xff0c;值……)&#xff0c;(值&#xff0c;值……)……&#xff1b; 2.指定列插入 insert into 表…

【JAVA WEB】JavaScript--函数 作用域 对象

目录 函数 语法格式 示例 定义没有参数列表&#xff0c;也没有返回值的一个函数 定义一个有参数列表 &#xff0c;有返回值的函数 关于参数个数 函数表达式 作用域 作用域链 对象 基本概念 创建对象 1.使用 字面量 创建对象 2.使用new Object()创建对象 3.使…

【教程】MySQL数据库学习笔记(二)——数据类型(持续更新)

写在前面&#xff1a; 如果文章对你有帮助&#xff0c;记得点赞关注加收藏一波&#xff0c;利于以后需要的时候复习&#xff0c;多谢支持&#xff01; 【MySQL数据库学习】系列文章 第一章 《认识与环境搭建》 第二章 《数据类型》 文章目录 【MySQL数据库学习】系列文章一、整…

Ps:创建联系表

Ps菜单&#xff1a;文件/自动/联系表 II Automate/Contact sheet II Photoshop 的“联系表 II” Contact Sheet II命令为快速生成图像集合的预览和打印目录提供了一种高效的方法。 此命令可以通过自动化过程读取指定的图像文件&#xff0c;然后根据用户定义的参数&#xff08;如…

初识webpack(二)解析resolve、插件plugins、dev-server

目录 (一)webpack的解析(resolve) 1.resovle.alias 2.resolve.extensions 3.resolve.mainFiles (二) plugin插件 1.CleanWebpackPlugin 2.HtmlWebpackPlugin 3.DefinePlugin (三)webpack-dev-server 1.开启本地服务器 2.HMR模块热替换 3.devServer的更多配置项 (…