李沐--动手学深度学习 批量规范化

1.理论

2.从零开始实现批量规范化

import torch
from torch import nn
from d2l import torch as d2l
from torch.utils.hooks import RemovableHandle
#从零开始实现批量规范化
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):#通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():#如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2,4)if len(X.shape) == 2:#使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim = 0)var = ((X - mean) ** 2).mean(dim=0)else:#使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。#这里需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0,2,3),keepdim = True)var = ((X - mean) ** 2).mean(dim = (0,2,3),keepdim = True)#训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)#更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta #缩放和移位return Y,moving_mean.data,moving_var.data#创建一个正确的BatchNorm层。这个层将保持适当的参数:拉伸gamma和偏移beta,这两个参数将在训练过程中更新。层将保存均值和方差的移动平均值,以便在模型预测期间随后使用。
class BatchNorm(nn.Module):#num_features: 完全连接层的输出数量或卷积层的输出通道数#num_dims:2表示完全连接层,4表示卷积层def __init__(self,num_features,num_dims):super().__init__()if num_dims == 2:shape = (1,num_features)else:shape = (1,num_features,1,1)#参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))#非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self,X):#如果X不在内存上,将moving_mean 和 moving_var复制到X所在的显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)#把要用的moving_mean挪到要运行的设备上self.moving_var = self.moving_var.to(X.device)#保存更新过的moving_mean 和 moving_varY,self.moving_mean,self.moving_var = batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y#更好理解如何应用BatchNorm,下面我们将其应用于LeNet模型
#批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10))#在Fashion-MNIST数据集上训练网络。这个代码与第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。
lr,num_epochs,batch_size = 1.0,10,256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
d2l.plt.show()
print(net[1].gamma.reshape((-1,)),net[1].beta.reshape((-1)))

3.调用框架实现

import torch
from torch import nn
from d2l import torch as d2l
from torch.utils.hooks import RemovableHandle#调用框架实现
#除了使用刚刚定义的BatchNorm,也可以直接使用深度学习框架中定义的BatchNorm
net = nn.Sequential(nn.Conv2d(1,6,kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(256,120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120,84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84,10))#使用相同超参数来训练模型。请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现
lr,num_epochs,batch_size = 1.0,10,256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java-使用HashMap压缩重复数据量以减少堆内存溢出的概率

使用 HashMap 压缩重复数据量以减少堆内存溢出的概率 为了减少堆内存溢出的概率,可以通过使用 HashMap 来压缩重复的数据量。这里我们可以通过以 下步骤实现: 创建一个 HashMap:用于存储数据及其出现次数。 遍历原始数据:将数据放入 HashMap 中,并统计每个数据出现的次…

PyCharm汉化:简单一步到胃!PyCharm怎么设置中文简体

最近在弄python的项目 一起加油哦 步骤: PyCharm的汉化可以通过两种主要方法完成: 方法一:通过PyCharm内置的插件市场安装中文语言包 1. 打开PyCharm,点击File -> Settings(在Mac上是PyCharm -> Preferences…

java一键生成数据库说明文档html格式

要验收项目了,要写数据库文档,一大堆表太费劲了,直接生成一个吧,本来想用个别人的轮子,网上看了几个,感觉效果不怎么好,自己动手写一个吧。抽空再把字典表补充进去就OK了 先看效果: …

Session Cookie Jwt Token常见web授权

基于分布式系统、同公司内、同一个 redis 作为存储,这个是目前主要的用法,去找开源框架都是这个逻辑;对外开放等使用参考 OAuth 2.0 能够标识出用户是谁,安全性相对高一些,就是好的方案。 Cookie Set 和 Get&#x…

Python3:多行文本内容转换为标准的cURL请求参数值

背景 在最近的工作中,经常需要处理一些接口请求的参数,参数来源形式很多,可能是Excel、知识库文档等,有些数据形式比较复杂,比如多行或者包含很多不同的字符,示例如下: **客服质检分析指引** …

【精选】分享9款AI毕业论文生成初稿题目网站

在当今学术研究领域,AI技术的应用日益广泛,尤其是在学术论文的撰写过程中。AI论文生成器的出现,极大地简化了学术写作流程,提高了写作效率。以下是9款推荐的AI毕业论文生成初稿的网站,它们各有特色,能够满足…

MFC工控项目实例之七点击下拉菜单弹出对话框

承接专栏《MFC工控项目实例之六CFile添加菜单栏》 1、在SEAL_PRESSUREDlg.h文件中添加代码 class CSEAL_PRESSUREDlg : public CDialog { ...afx_msg void OnTypeManage(); ... } 2、在SEAL_PRESSUREDlg.cpp文件中添加代码 BEGIN_MESSAGE_MAP(CSEAL_PRESSUREDlg, CDialog)//…

MySQL的源码安装及基本部署(基于RHEL7.9)

这里源码安装mysql的5.7.44版本 一、源码安装 1.下载并解压mysql , 进入目录: wget https://downloads.mysql.com/archives/get/p/23/file/mysql-boost-5.7.44.tar.gz tar xf mysql-boost-5.7.44.tar.gz cd mysql-5.7.44/ 2.准备好mysql编译安装依赖: yum install cmake g…

数据结构:用栈实现队列(232)LeetCode

请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作(push、pop、peek、empty): 实现 MyQueue 类: void push(int x) 将元素 x 推到队列的末尾 int pop() 从队列的开头移除并返回元素 int peek() 返回队列开…

Python爬虫——简单网页抓取(实战案例)小白篇

Python 爬虫是一种强大的工具,用于从网页中提取数据。这里,我将通过一个简单的实战案例来展示如何使用 Python 和一些流行的库(如 requests 和 BeautifulSoup)来抓取网页数据。 实战案例:抓取一个新闻网站的头条新闻标…

Windows上传Linux文件行尾符转换

Windows上传Linux文件行尾符转换 1、Windows与Linux文件行尾符2、Windows与Linux文件格式转换 1、Windows与Linux文件行尾符 众所周知,Windows、Mac与Linux三种系统的文件行尾符不同,其中 Windows文件行尾符(\r\n): L…

使用kafka改造分布式事务

文章目录 1、kafka确保消息不丢失?1.1、生产者端确保消息不丢失1.2、kafka服务端确保消息不丢失1.3、消费者确保正确无误的消费 2、生产者发送消息 KafkaService3、UserInfoServiceImpl -> login()4、service-account - > AccountListener.java 1、kafka确保消…

day31-测试之性能测试工具JMeter的功能概要、元件作用域和执行顺序

目录 一、JMeter的功能概要 1.1.文件目录介绍 1).bin目录 2).docs目录 3).printable_docs目录 4).lib目录 1.2.基本配置 1).汉化 2).主题修改 1.3.基本使用流程 二、JMeter元件作用域和执行顺序 2.1.名称解释 2.2.基本元件 2.3.元件作用域 1).核心 2).提示 3).作用域的原则 2.…

常用PHP JS MySQL 常用方法记录

常用PHP JS MySQL 常用方法记录 MySQL 1)查询 Select 1.1)FROM_UNIXTIME 根据创建时间 时间戳 筛选 WHEREFROM_UNIXTIME(kl.created_at) BETWEEN 2024-08-01 00:00:01 AND 2024-08-08 23:59:59 1.2)DATE_FORMAT 格式化时间戳 DATE_FOR…

Redis 实现哨兵模式

目录 1 哨兵模式介绍 1.1 什么是哨兵模式 1.2 sentinel中的三个定时任务 2 配置哨兵 2.1 实验环境 2.2 实现哨兵的三条参数: 2.3 修改配置文件 2.3.1 MASTER 2.3.2 SLAVE 2.4 将 sentinel 进行备份 2.5 开启哨兵模式 2.6 故障模拟 3 在整个架构中可能会出现的问题 …

go中 panicrecoverdefer机制

go的defer机制-CSDN博客 常见panic场景 数组或切片越界,例如 s : make([]int, 3); fmt.Println(s[5]) 会引发 panic: runtime error: index out of range空指针调用,例如 var p *Person; fmt.Println(p.Name) 会引发 panic: runtime error: invalid m…

Android Init Language

Android Init Language 安卓初始化语言,是一种用于配置和管理 Android 系统服务的专用脚本语言。主要用于编写 .rc 文件(比如我们熟知的init.rc文件),这些文件在系统启动时由 init 进程读取和执行,从而设置和启动系统服…

Mako 模板语言

Mako 模板语言 Mako的哲学:Python is great scripting language ,don’t reinvent the wheel, your template can handle it !, api非常简单, ####入门 Template类是创建模板和渲染模板的核心类 from mako.template import Template mytemplate Template("hello world&…

网络通信tcp

一、udp案例 二、基于tcp: tcp //c/s tcp 客户端: 1.建立连接 socket bind connect 2.通信过程 read write close tcp服务器: 1.建立连接 socket bind listen accept 2.通信过程 read write close connect函数 int connect(int sockfd, con…

Git克隆仓库太大导致拉不下来的解决方法 fatal: fetch-pack: invalid index-pack output

一般这种问题是因为某个文件/某个文件夹/某些文件夹过大导致整个项目超过1G了导致的 试过其他教程里的设置depth为1,也改过git的postBuffer,都不管用 最后还是靠克隆指定文件夹这种方式成功把项目拉下来 1. Git Bash 输入命令 git clone --filterblob:none --sparse 项目路径…