李沐--动手学深度学习 批量规范化

1.理论

2.从零开始实现批量规范化

import torch
from torch import nn
from d2l import torch as d2l
from torch.utils.hooks import RemovableHandle
#从零开始实现批量规范化
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):#通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():#如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2,4)if len(X.shape) == 2:#使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim = 0)var = ((X - mean) ** 2).mean(dim=0)else:#使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。#这里需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0,2,3),keepdim = True)var = ((X - mean) ** 2).mean(dim = (0,2,3),keepdim = True)#训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)#更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta #缩放和移位return Y,moving_mean.data,moving_var.data#创建一个正确的BatchNorm层。这个层将保持适当的参数:拉伸gamma和偏移beta,这两个参数将在训练过程中更新。层将保存均值和方差的移动平均值,以便在模型预测期间随后使用。
class BatchNorm(nn.Module):#num_features: 完全连接层的输出数量或卷积层的输出通道数#num_dims:2表示完全连接层,4表示卷积层def __init__(self,num_features,num_dims):super().__init__()if num_dims == 2:shape = (1,num_features)else:shape = (1,num_features,1,1)#参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))#非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self,X):#如果X不在内存上,将moving_mean 和 moving_var复制到X所在的显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)#把要用的moving_mean挪到要运行的设备上self.moving_var = self.moving_var.to(X.device)#保存更新过的moving_mean 和 moving_varY,self.moving_mean,self.moving_var = batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y#更好理解如何应用BatchNorm,下面我们将其应用于LeNet模型
#批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10))#在Fashion-MNIST数据集上训练网络。这个代码与第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。
lr,num_epochs,batch_size = 1.0,10,256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
d2l.plt.show()
print(net[1].gamma.reshape((-1,)),net[1].beta.reshape((-1)))

3.调用框架实现

import torch
from torch import nn
from d2l import torch as d2l
from torch.utils.hooks import RemovableHandle#调用框架实现
#除了使用刚刚定义的BatchNorm,也可以直接使用深度学习框架中定义的BatchNorm
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(256,120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120,84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84,10))#使用相同超参数来训练模型。请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现
lr,num_epochs,batch_size = 1.0,10,256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyCharm汉化:简单一步到胃!PyCharm怎么设置中文简体

最近在弄python的项目 一起加油哦 步骤: PyCharm的汉化可以通过两种主要方法完成: 方法一:通过PyCharm内置的插件市场安装中文语言包 1. 打开PyCharm,点击File -> Settings(在Mac上是PyCharm -> Preferences…

java一键生成数据库说明文档html格式

要验收项目了,要写数据库文档,一大堆表太费劲了,直接生成一个吧,本来想用个别人的轮子,网上看了几个,感觉效果不怎么好,自己动手写一个吧。抽空再把字典表补充进去就OK了 先看效果: …

Python3:多行文本内容转换为标准的cURL请求参数值

背景 在最近的工作中,经常需要处理一些接口请求的参数,参数来源形式很多,可能是Excel、知识库文档等,有些数据形式比较复杂,比如多行或者包含很多不同的字符,示例如下: **客服质检分析指引** …

【精选】分享9款AI毕业论文生成初稿题目网站

在当今学术研究领域,AI技术的应用日益广泛,尤其是在学术论文的撰写过程中。AI论文生成器的出现,极大地简化了学术写作流程,提高了写作效率。以下是9款推荐的AI毕业论文生成初稿的网站,它们各有特色,能够满足…

MFC工控项目实例之七点击下拉菜单弹出对话框

承接专栏《MFC工控项目实例之六CFile添加菜单栏》 1、在SEAL_PRESSUREDlg.h文件中添加代码 class CSEAL_PRESSUREDlg : public CDialog { ...afx_msg void OnTypeManage(); ... } 2、在SEAL_PRESSUREDlg.cpp文件中添加代码 BEGIN_MESSAGE_MAP(CSEAL_PRESSUREDlg, CDialog)//…

MySQL的源码安装及基本部署(基于RHEL7.9)

这里源码安装mysql的5.7.44版本 一、源码安装 1.下载并解压mysql , 进入目录: wget https://downloads.mysql.com/archives/get/p/23/file/mysql-boost-5.7.44.tar.gz tar xf mysql-boost-5.7.44.tar.gz cd mysql-5.7.44/ 2.准备好mysql编译安装依赖: yum install cmake g…

Python爬虫——简单网页抓取(实战案例)小白篇

Python 爬虫是一种强大的工具,用于从网页中提取数据。这里,我将通过一个简单的实战案例来展示如何使用 Python 和一些流行的库(如 requests 和 BeautifulSoup)来抓取网页数据。 实战案例:抓取一个新闻网站的头条新闻标…

Windows上传Linux文件行尾符转换

Windows上传Linux文件行尾符转换 1、Windows与Linux文件行尾符2、Windows与Linux文件格式转换 1、Windows与Linux文件行尾符 众所周知,Windows、Mac与Linux三种系统的文件行尾符不同,其中 Windows文件行尾符(\r\n): L…

使用kafka改造分布式事务

文章目录 1、kafka确保消息不丢失?1.1、生产者端确保消息不丢失1.2、kafka服务端确保消息不丢失1.3、消费者确保正确无误的消费 2、生产者发送消息 KafkaService3、UserInfoServiceImpl -> login()4、service-account - > AccountListener.java 1、kafka确保消…

day31-测试之性能测试工具JMeter的功能概要、元件作用域和执行顺序

目录 一、JMeter的功能概要 1.1.文件目录介绍 1).bin目录 2).docs目录 3).printable_docs目录 4).lib目录 1.2.基本配置 1).汉化 2).主题修改 1.3.基本使用流程 二、JMeter元件作用域和执行顺序 2.1.名称解释 2.2.基本元件 2.3.元件作用域 1).核心 2).提示 3).作用域的原则 2.…

Redis 实现哨兵模式

目录 1 哨兵模式介绍 1.1 什么是哨兵模式 1.2 sentinel中的三个定时任务 2 配置哨兵 2.1 实验环境 2.2 实现哨兵的三条参数: 2.3 修改配置文件 2.3.1 MASTER 2.3.2 SLAVE 2.4 将 sentinel 进行备份 2.5 开启哨兵模式 2.6 故障模拟 3 在整个架构中可能会出现的问题 …

go中 panicrecoverdefer机制

go的defer机制-CSDN博客 常见panic场景 数组或切片越界,例如 s : make([]int, 3); fmt.Println(s[5]) 会引发 panic: runtime error: index out of range空指针调用,例如 var p *Person; fmt.Println(p.Name) 会引发 panic: runtime error: invalid m…

网络通信tcp

一、udp案例 二、基于tcp: tcp //c/s tcp 客户端: 1.建立连接 socket bind connect 2.通信过程 read write close tcp服务器: 1.建立连接 socket bind listen accept 2.通信过程 read write close connect函数 int connect(int sockfd, con…

Git克隆仓库太大导致拉不下来的解决方法 fatal: fetch-pack: invalid index-pack output

一般这种问题是因为某个文件/某个文件夹/某些文件夹过大导致整个项目超过1G了导致的 试过其他教程里的设置depth为1,也改过git的postBuffer,都不管用 最后还是靠克隆指定文件夹这种方式成功把项目拉下来 1. Git Bash 输入命令 git clone --filterblob:none --sparse 项目路径…

探索Unity3D URP后处理在UI控件Image上的应用

探索Unity3D URP后处理在UI控件Image上的应用 前言初识URP配置后处理效果将后处理应用于UI控件方法一:自定义Shader方法二:RenderTexture的使用 实践操作步骤一:创建RenderTexture步骤二:UI渲染至RenderTexture步骤三:…

视频如何转gif?分享这几款软件!

在这个快节奏、高创意的互联网时代,动图(GIF)以其独特的魅力成为了社交媒体、聊天软件中的宠儿。它们不仅能瞬间抓住眼球,还能让信息传递更加生动有趣。然而,你是否曾为如何将精彩瞬间从视频中精准截取并转换成GIF而苦…

​北斗终端:无人驾驶领域的导航新星

一、北斗终端在无人驾驶领域的应用 北斗终端,作为我国自主研发的北斗卫星导航系统的重要组成部分,其在无人驾驶领域中的应用正逐步显现其独特魅力。北斗系统的高精度、高可靠性和良好的抗干扰性能,为无人驾驶车辆提供了精确的定位和导航服务…

关于超长字符串/文本对应的数据从excel导入到PL/SQL中的尝试

问题: 1.字符串太长 2.str绑定之的结尾null缺失 将csv文件导入到PL/SQL表中存在的一些问题 1.本来我是需要将exceL上的几十条数据导入到PL/SQL数据库的一张表中,结果我花了许多时间 去导入。 想想一般情况下也就几十条数据,直接复制粘贴就…

C语言-有两个磁盘文件A和B,各存放一行字母,今要求把这两个文件的信息合并(按字母顺序排列),输出到一个新文件C中去-深度代码解析

🌏个人博客:尹蓝锐的博客 1、题目要求 有两个磁盘文件A和B,各存放一行字母,今要求把这两个文件的信息合并(按字母顺序排列),输出到一个新文件C中去 2、准备工作 问题1:为什么不需要…

chrome打印dom节点不显示节点信息

正常直接console dom节点 代码改成 var parser new DOMParser(); var docDom parser.parseFromString(testHtml, text/html); console.log(docDom) let htmlHeader ref< HTMLElement | null>(null) let htmlBoby ref< HTMLElement | null>(null) htmlHeader.v…