【一起深度学习——批量规范化】

批量规范化

  • 1、为啥要批量规范化呢?
  • 2、如何批量规范化呢?
  • 3、实现批量归一化。
  • 4、定义BatchNorm层:
  • 5、定义神经网络:
  • 6、开始训练:

1、为啥要批量规范化呢?

1、可持续加速深层网络的收敛速度。
2、对于深层网络来说非常复杂,容易导致过拟合。

2、如何批量规范化呢?

均值u = (∑x)/B B是样本个数
方差o^2 = (∑(x - u)^2)/B + c (c是小噪声) 为啥要设置这个c呢,避免分母除0
BN = gamma * (x - u)/o + beta

3、实现批量归一化。

代码如下:

# moving_mean :均值, moving_var 方差, eps:就是上边那个 c(小噪声),避免太小。 momentum : 用于更新moving_mean 和moving_var
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):#用于检测当前是训练模式还是预测模式if not torch.is_grad_enabled():#如果是在训练模式下,直接使用传入的移动平均所得到的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:# X.shape表示呢,X这个张量的形状维度大小。#例如:全连接层:(样本数,输入特征) 而,卷积层:(批量大小,输出通道,高度,宽度)assert len(X.shape) in (2,4)if len(X.shape) == 2:#使用全连接层的情况,计算特征维上的均值和方差。mean = X.mean(dim=0)  #按列来计算特征值的均值var = ((X - mean) **2 ).mean(dim=0) #均值方差else:#对于卷积层来说,(批量大小,通道,高度,宽度)# 理解一下这里的(dim=(0,2,3)),对于上边的dim =0,相当于压缩列方向。# 那么看dim =(0,2,3),相当于压缩,批量方向,高度方向,宽度方向,最终会只剩下通道方向,所以结果是:1*n,1*1mean = X.mean(dim=(0,2,3),keepdim =True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + betareturn Y,moving_mean.data,moving_var.data

4、定义BatchNorm层:

class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:# 全连接层shape = (1, num_features)else:# 卷积层,高度和宽度都设置为1 ,是为了使用广播机制。shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

5、定义神经网络:

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))

6、开始训练:

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/7971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

罗德与施瓦茨 SMC100A信号发生器9kHz至3.2 GHz

罗德与施瓦茨 SMC100A信号发生器,9 kHz - 3.2 GHz 罗德与施瓦茨 SMC100A 以极具吸引力的价格提供出色的信号质量。它覆盖的频率范围为 9 kHz 至 1.1 GHz 或 3.2 GHz。输出功率为典型值。> 17 dBm。所有重要功能(AM/FM/φM/脉冲调制)均已集…

代码随想录算法训练营第六十天| 647. 回文子串,516.最长回文子序列,动态规划总结篇

题目与题解 参考资料:动态规划总结篇 647. 回文子串 题目链接:647. 回文子串 代码随想录题解:647. 回文子串 视频讲解:动态规划,字符串性质决定了DP数组的定义 | LeetCode:647.回文子串_哔哩哔哩_bilibili …

【Camera2完整流程分析三】从Log角度跟踪分析原生Camera2相机分析相机拍照流程

一,概述 通过阅读本篇文件后,你会了解到原生Camera2的从app层到framework,再到camera service层,最后到Hal3层的完整代码架构和代码流程。学习本篇文章后可以对拍照take picture快速上手,并能通过log跟踪来快速分析和定位问题,具体知识点的如下: 1,Camera2在app层摄像…

【busybox记录】【shell指令】unexpand

目录 内容来源: 【GUN】【unexpand】指令介绍 【busybox】【unexpand】指令介绍 【linux】【unexpand】指令介绍 使用示例: 空格转化成制表符 - 默认输出 空格转化成制表符 - 转换所有的空格 空格转化成制表符 - 指定制表位 常用组合指令&#…

项目提交到空的git仓库流程

流程: # 初始化 Git 仓库 git init # 如果遇到 "detected dubious ownership" 的错误,可以添加 safe.directory 配置以解决 git config --global --add safe.directory T:/project/heima-leadnews # 将当前目录下的所有文件添加到 Git 暂存区…

构造照亮世界——快速沃尔什变换 (FWT)

博客园 我的博客 快速沃尔什变换解决的卷积问题 快速沃尔什变换(FWT)是解决这样一类卷积问题: ci∑ij⊙kajbkc_i\sum_{ij\odot k}a_jb_k ci​ij⊙k∑​aj​bk​其中,⊙\odot⊙ 是位运算的一种。举个例子,给定数列 a,…

MATLAB基础应用精讲-【数模应用】非参数检验(附python、MATLAB和R语言代码实现)

目录 几个相关概念 算法原理 什么是非参数检验 何时使用非参数检验

腾讯安全客户端(电脑管家部门)一面

上来介绍部门,之后自我介绍 说了是个喜欢每天都学点新东西的人,然后平常也会在课余时间之外去做点项目方面的学习,比如Web项目做出来就是因为兴趣。喜欢结构性的东西,有一门课叫电路电子学一次考试是专业第二。其他也都还可以&am…

微信小程序交互增强:实现上拉加载、下拉刷新与一键返回顶部【代码示例】

微信小程序交互增强:实现上拉加载、下拉刷新与一键返回顶部【代码示例】 基础概念实现步骤与代码示例1. 下拉刷新2. 上拉加载更多3. 返回顶部 性能优化与安全考虑结语与讨论 在微信小程序的开发过程中,提供流畅的用户体验至关重要,其中上拉加…

小米手机miui14 android chrome如何取消网页自动打开app

搜索媒体打开应用 选择你要阻止打开的app,以github为例 取消勾选打开支持的链接。 参考:https://www.reddit.com/r/chrome/s/JBsGkZDkRZ

创建禁止操作区域并且添加水印

css 设置 : 引用换成自己就好 .overlay {z-index: 1000;cursor: none; /*设置为不可点击*/user-select: none; /*设置为不可选择*/contenteditable: false; /*设置为不可编辑*/draggable: false; /*设置为不可拖动*/position: absolute;top: 0;left: 0;width: 100…

git bash退出vim编译模式

解决方法: 1.按esc键(回到命令模式) 此时是没有分号让我们在后面输入命令的 2.按shift键: 3.再输入:wq,并按enter键 此时我们发现又回到git bash窗口 希望对大家有所帮助!

JavaScript学习—网络请求

在 JavaScript 中,XMLHttpRequest 对象是一个用于与服务器交换数据的接口,允许在后台与服务器进行异步通信。这使得网页可以在不中断用户交互的情况下从服务器请求数据。 方法 open():用于设置请求的类型、URL和是否异步处理请求。 var xhr …

数据库-脏读

脏读(Dirty Read)是数据库并发控制中的一个概念,指的是一个事务读取了另一个尚未提交的事务的修改。由于另一个事务的修改可能最终会被撤销(即发生回滚操作),因此,当前事务读取到的数据可能是“…

一览函数式编程

文章目录 一、 什么是函数式编程1.1 编程范式1.1.1 命令式编程(Imperative Programming)范式1.1.2 声明式编程(Declarative Programming)范式1.1.3 函数式编程(Functional Programming)范式1.1.4 面向对象编程(Object-Oriented Programming)范式1.1.5 元编程(Metaprogramming)范…

ThinkPHP5.1 创建控制器类

在ThinkPHP中,控制器是MVC模式中的核心组件之一,负责接收用户请求并处理相应的业务逻辑。在本篇技术博客中,我们将深入探讨ThinkPHP5.1中的控制器操作,包括创建控制器、路由绑定、请求参数获取等方面的知识点。 1.创建控制器 在T…

(1day)致远M3 log 敏感信息泄露漏洞(Session)复现

前言 系统学习web漏洞挖掘以及项目实战也有一段时间了,发现在漏洞挖掘过程中难免会碰到一些历史漏洞,来帮助自己或是提高自己挖洞和及时发现漏洞效率,于是开始创建这个专栏,对第一时间发现的1day以及历史漏洞进行复现,来让自己更加熟悉漏洞类型以及历史漏洞,方便自己在后续的项…

商家制作微信小程序有什么好处?微信小程序的制作有哪些步骤和流程

微信小程序全面指南 微信小程序是微信生态系统中一项革命性的功能,为希望与庞大的微信用户群体互动的企业提供了独特的融合便捷性和功能性的体验。本全面指南深入探讨了微信小程序的世界,强调了其重要性、工作原理以及实际用例,特别是针对企…

开发组合php+mysql 人才招聘小程序源码搭建 招聘平台系统源码+详细图文搭建部署教程

随着互联网的快速发展,传统的招聘方式已经不能满足企业和求职者的需求。为了提高招聘效率,降低招聘成本,越来越多的人开始关注人才招聘小程序、在线招聘平台。分享一个人才招聘小程序源码及搭建,让招聘更加高效便捷。系统是运营级…

windows安装ElasticSearch以及踩坑

1.下载 elasticsearch地址:Past Releases of Elastic Stack Software | Elastichttps://www.elastic.co/cn/downloads/past-releases#elasticsearch IK分析器地址:infinilabs/analysis-ik: 🚌 The IK Analysis plugin integrates Lucene IK…