BatchNorm介绍:卷积神经网络中的BN

一、BN介绍

1.原理

在机器学习中让输入的数据之间相关性越少越好,最好输入的每个样本都是均值为0方差为1。在输入神经网络之前可以对数据进行处理让数据消除共线性,但是这样的话输入层的激活层看到的是一个分布良好的数据,但是较深的激活层看到的的分布就没那么完美了,分布将变化的很严重。这样会使得训练神经网络变得更加困难。所以添加BatchNorm层,在训练的时候BN层使用batch来估计数据的均值和方差,然后用均值和方差来标准化这个batch的数据,并且随着不同的batch经过网络,均值和方差都在做累计平均。在测试的时候就直接作为标准化的依据。

这样的方法也有可能导致降低神经网络的表示能力,因为某些层的全局最优的特征可能不是均值为0或者方差为1的。所以BN层也是能够进行学习每个特征维度的缩放gamma和平移beta的来避免这样的情况。

2.BN层前向传播

def batchnorm_forward(x, gamma, beta, bn_param):"""先进行标准化再进行平移缩放running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: (N, D) 输入的数据- gamma: (D,) 每个特征维度数据的缩放- beta: (D,) 每个特征维度数据的偏移- bn_param: 字典,有如下键值:- mode: 'train'/'test' 必须指定- eps: 一个常量为了维持数值稳定,保证不会除0- momentum: 动量- running_mean: (D,) 积累的均值- running_var: (D,) 积累的方差Returns:- out: (N,D)- cache: 反向传播时需要的数据"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))out, cache = None, Noneif mode == 'train':sample_mean = np.mean(x, axis=0)sample_var = np.var(x, axis=0)# 先标准化x_hat = (x - sample_mean)/(np.sqrt(sample_var + eps))# 再做缩放偏移out = gamma * x_hat + betacache = (gamma, x, sample_mean, sample_var, eps, x_hat)running_mean = momentum * running_mean + (1-momuntum)*sample_meanrunning_var = momentum * running_var + (1-momentum)*sample_varelif mode == 'test':# 先标准化#x_hat = (x - running_mean)/(np.sqrt(running_var+eps))# 再做缩放偏移#out = gamma * x_hat + beta# 或者是下面的骚写法scale = gamma/(np.sqrt(running_var + eps))out = x*scale + (beta - running_mean*scale)else:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)bn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

3.BN层反向传播

def batchnorm_barckward(out, cache):"""反向传播的简单写法,易于理解Inputs:- dout: (N,D) dloss/dout- cache: (gamma, x, sample_mean, sample_var, eps, x_hat)Returns:- dx: (N,D)- dgamma: (D,) 每个维度的缩放和平移参数不同- dbeta: (D,)"""dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_1 = gamma * dout # dloss/dx_hat = dloss/dout * gamma (N, D)dx_2_b = np.sum((x - u_b) * dx_1, axis=0)dx_2_a = ((sigma_squared_b + eps)**-0.5)*dx_1dx_3_b = (-0.5) * ((sigma_squared_b + eps)**-1.5)*dx_2_bdx_4_b = dx_3_b * 1dx_5_b = np.ones_like(x)/N * dx_4_bdx_6_b = 2*(x-u_b)*dx_5_bdx_7_a = dx_6_b*1 + dx_2_a*1dx_7_b = dx_6_b*1 * dx_2_a*1dx_8_b = -1*np.sum(dx_7_b, axis=0)dx_9_b = np.ones_like(x)/N * dx_8_bdx_10 = dx_9_b + dx_7_adgamma = np.sum(x_hat * dout, axis=0)dbeta = np.sum(dout, axis=0)dx = dx_10return dx, dgamma, dbeta

下面是直接使用公式来计算:

def batchnorm_backward_alt(dout, cache):dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_hat = dout * gammadvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmeandgamma = np.sum(x_hat * dout, axis = 0)dbeta = np.sum(dout , axis = 0)return dx, dgamma, dbeta

4.BN有什么作用

  1. 对于不好的权重初始化有更高的鲁棒性,仍然能得到较好的效果。
  2. 能更好的避免过拟合。
  3. 解决梯度消失/爆炸问题,BN防止了前向传播的时候数值过大或者过小,这样就能让反向传播时梯度处于一个较好的区间内。

二、卷积神经网络中的BN

1.前向传播

def spatial_batchnorm_forward(x, gamma, beta, bn_param):"""利用普通神经网络的BN来实现卷积神经网络的BNInputs:- x: (N, C, H, W)- gamma: (C,)缩放系数- beta: (C,)平移系数- bn_param: 包含如下键的字典- mode: 'train'/'test'必须的键- eps: 数值稳定需要的一个较小的值- momentum: 一个常量,用来处理running mean和var的。如果momentum=0 那么之前不利用之前的均值和方差。momentum=1表示不利用现在的均值和方差,一般设置momentum=0.9- running_mean: (C,)- running_var: (C,)Returns:- out: (N, C, H, W)- cache: 反向传播需要的数据,这里直接使用了普通神经网络的cache"""N, C, H, W = x.shape# transpose之后(N, W, H, C) channel在这里就可以看成是特征temp_out, cache = batchnorm_forward(x.transpose(0, 3, 2, 1).reshape((N*H*W, C)), gamma, beta, bn_param)# 再恢复shapeout = temp_output.reshape(N, W, H, C).transpose(0, 3, 2, 1)return out, cache

2.反向传播

def spatial_batchnorm_backward(dout, cache):"""利用普通神经网络的BN反向传播实现卷积神经网络中的BN反向传播Inputs:- dout: (N, C, H, W) 反向传播回来的导数- cache: 前向传播时的中间数据Returns:- dx: (N, C, H, W)- dgamma: (C,) 缩放系数的导数- dbeta: (C,) 偏移系数的导数"""dx, dgamma, dbeta = None, None, NoneN, C, H, W = dout.shape# 利用普通神经网络的BN进行计算 (N*H*W, C)channel看成是特征维度dx_temp, dgamma, dbeta = batchnorm_backward_alt(dout.transpose(0, 3, 2, 1).reshape((N*H*W, C)), cache)# 将shape恢复dx = dx_temp.reshape(N, W, H, C).transpose(0, 3, 2, 1)return dx, dgamma, dbeta

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/680474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis执行流程

MyBatis是一个流行的Java持久层框架,它封装了JDBC操作,使开发者可以通过XML或注解的方式映射SQL语句,并将POJO与数据库表之间进行映射。了解MyBatis的执行流程,可以帮助开发者更好地理解其内部工作机制,优化代码并解决…

企业微信自动推送机器人的应用与价值

随着科技的快速发展,企业微信自动推送机器人已经成为了企业数字化转型的重要工具。这种机器人可以自动推送消息、执行任务、提供服务,为企业带来了许多便利。本文将探讨企业微信自动推送机器人的应用和价值。 一、企业微信自动推送机器人的应用 企业微信…

无人机飞行控制系统功能,多旋翼飞行控制系统概述

飞行控制系统存在的意义 行控制系统通过高效的控制算法内核,能够精准地感应并计算出飞行器的飞行姿态等数据,再通过主控制单元实现精准定位悬停和自主平稳飞行。 在没有飞行控制系统的情况下,有很多的专业飞手经过长期艰苦的练习&#xff0…

「数据结构」串

串的定义和实现 串的定义 串: 即字符串,零个或多个字符组成的有限序列串的长度:串中字符的个数n空串:n0时的串子串:串中任意多个连续的字符组成的子序列主串:包含子串的串字符在主串中的位置:字符在串中的…

【Docker进阶】镜像制作-用Dockerfile制作镜像(一)

进阶一 docker镜像制作 文章目录 进阶一 docker镜像制作用dockerfile制作镜像dockerfile是什么dockerfile格式为什么需要dockerfileDockerfile指令集合FROMMAINTAINERLABELCOPYENVWORKDIR 用dockerfile制作镜像 用快照制作镜像的缺陷: 黑盒不可重复臃肿 docker…

嵌入式大厂面试题(1)—— CVTE

从本篇开始将会更新历年来各个公司的面试题与面经,题目来自于网上各个平台以及博主自己遇到的,如果大家有所帮助,帮忙点点赞和关注吧! 岗位:嵌入式软件工程师。 面试时间:20分钟。 面试 1 、简历中写了做过…

Kafka 入门笔记

课程地址 概述 定义 Kafka 是一个分布式的基于发布/订阅模式的消息队列(MQ) 发布/订阅:消息的发布者不会将消息直接发送给特定的订阅者,而是将发布的消息分为不同的类别,订阅者只接受感兴趣的消息 消息队列 消息队…

HCIA-Datacom实验指导手册:4.3 实验三:网络地址转换配置实验

HCIA-Datacom实验指导手册:4.3 实验三:网络地址转换配置实验 一、实验介绍:二、 思考题与附加内容 一、实验介绍: NAT的作用: 1、很大程度提高网络安全性。 2、控制内外网网络联通性问题。 特点: 1&#…

JDK 11 vs JDK 8:探索Java的新特性和改进

随着技术的不断进步,Java开发工具包(JDK)也在不断演变,为开发者带来更高效、更安全的编程体验。在这篇文章中,我们将重点探讨JDK11相较于JDK 8所引入的一些新特性和改进,以便您能够更好地了解Java的最新发展…

leetcode:买卖股票最佳时机二

思路: 使用贪心算法:局部最优是将买卖过程中产生的正数进行相加,进而使得最后结果最大(全局最优)。 price [7,1,5,10,3,6,4] -6,4,5,-7,3,-2 正数相加就得到了最大 代码实现: 1.循环中下标从1开始 …

大数据的基础探索之大数据时代

前言:大数据已经是大势所趋,在这个网络时代能够不断地整合资源的人本身也是一种能力拥有者,在这个时代,如果一个人可以掌握数据分析工具,利用好云计算的能力,对于自己的个人而言来说都是一个极其重要的参与…

【PyTorch】张量(Tensor)的生成

PyTorch深度学习总结 第一章 Pytorch中张量(Tensor)的生成 文章目录 PyTorch深度学习总结一、什么是PyTorch?二、张量(Tensor)1、张量的数据类型2、张量生成和信息获取 总结 一、什么是PyTorch? PyTorch是一个开源的深度学习框架,基于Python…

20240212请问如何将B站下载的软字幕转换成为SRT格式?

20240212请问如何将B站下载的软字幕转换成为SRT格式? 2024/2/12 12:47 百度搜索:字幕 json 转 srt json srt https://blog.csdn.net/a_wh_white/article/details/120687363?share_token2640663e-f468-4737-9b55-73c808f5dcf0 https://blog.csdn.net/a_w…

上位机图像处理和嵌入式模块部署(利用python开发软件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 开发windows和linux软件的时候,大家一般都是习惯于用c/c语言进行开发,但是目前来说很多的开发板都是支持python语言开发的。…

RK3588平台开发系列讲解(视频篇)RKMedia 数据流向

文章目录 一、 获取RKMedia模块通道中的数据二、RKMedia的数据源和接收者三、模块通道绑定API调用 沉淀、分享、成长,让自己和他人都能有所收获!😄 📢RKMedia是RK提供的一种多媒体处理方案,可实现音视频捕获、音视频输…

服务治理中间件-Eureka

目录 简介 搭建Eureka服务 注册服务到Eureka 简介 Eureka是Spring团队开发的服务治理中间件,可以轻松在项目中,实现服务的注册与发现,相比于阿里巴巴的Nacos、Apache基金会的Zookeeper,更加契合Spring项目,缺点就是…

前端vue 数字 字符串 丢失精度问题

1.问题 后端返回的数据 是这样的 一个字符串类型的数据 前端要想显示这个 肯定需要使用Json.parse() 转换一下 但是 目前有一个问题 转换的确可以 showId:1206381711026823172 有一个这样的字段 转换了以后 发现 字段成了1206381711026823200 精度直接丢了 原本的数据…

MySQL监控Innodb信息

Innodb监控 Innodb由于支持事务操作,是mysql中使用最多的存储引擎,所以如何监控Innodb存储引擎以进行性能优化是在使用mysql过程中遇到最多的,那么如何进行监控呢? show engine -- 显示innodb存储引擎状态的统计和配置信息show en…

MogaNet实战:使用MogaNet实现图像分类任务(一)

文章目录 摘要安装包安装timm 数据增强Cutout和MixupEMA项目结构计算mean和std生成数据集 摘要 论文:https://arxiv.org/pdf/2211.03295.pdf 作者多阶博弈论交互这一全新视角探索了现代卷积神经网络的表示能力。这种交互反映了不同尺度上下文中变量间的相互作用效…

C语言函数指针实现函数参数化

之前学习了基本的函数指针;函数指针有多种用途;下面看一下函数参数化; 函数参数化是指通过函数指针将函数的某些行为参数化。这样可以在调用函数时动态地指定函数的行为。 新建一个单文档工程;下述增加的函数声明加到视类cpp文件的头部,函数体加到视类cpp文件的尾部,在…