动手学深度学习(Pytorch版)代码实践 -卷积神经网络-28批量规范化

28批量规范化

"""可持续加速深层网络的收敛速度"""
import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as pltdef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):"""实现一个具有张量的批量规范化层。"""# 如果启用了梯度计算,torch.is_grad_enabled() 返回 True;否则返回 False。if not torch.is_grad_enabled():# torch.no_grad() 是一个上下文管理器,用于临时禁用梯度计算# torch.enable_grad() 是一个上下文管理器,用于在禁用梯度计算的上下文中重新启用梯度计算。X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0) # 计算张量 X 沿着第 0 维的平均值# 维度 0 代表样本数量,即沿着每个特征计算所有样本的平均值。var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * var# gamma 和 beta 的更新是通过反向传播和优化器自动完成的Y = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.dataclass BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1# 经过归一化处理后的数据均值接近于零。因此,将滑动均值初始化为0,是对数据初始均值的一种合理假设。self.moving_mean = torch.zeros(shape)# 方差表示数据的离散程度。将滑动方差初始化为1,意味着假设数据的初始方差为1,# 即数据分布接近标准正态分布。这样初始化可以避免初始阶段的数值不稳定。self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在GPU上                              if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y#使用批量规范化层的 LeNet
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2BatchNorm(num_features=6, num_dims=4),nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5BatchNorm(num_features=16, num_dims=4),nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120BatchNorm(num_features=120, num_dims=2),nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84BatchNorm(num_features=84, num_dims=2),nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size)
# lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# plt.show()# loss 0.200, train acc 0.925, test acc 0.812
# 34957.3 examples/sec on cuda:0# loss 0.189, train acc 0.928, test acc 0.894
# 33471.2 examples/sec on cuda:0#简明实现
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.ReLU(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(),nn.Linear(84, 10)
)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()# nn.Sigmoid()
# loss 0.263, train acc 0.902, test acc 0.833
# 46935.0 examples/sec on cuda:0# nn.ReLU()
# loss 0.224, train acc 0.914, test acc 0.874
# 44479.2 examples/sec on cuda:0
"""
通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。
"""

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/32571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Swift 中的动态数组

Swift 的 Array 类型是一种强大而灵活的集合类型,可以根据需要自动扩展或缩减其容量。 动态数组的基本概念 Swift 中的数组是基于动态数组(dynamic array)的概念实现的。动态数组能够根据需要自动调整其容量,以容纳新增的元素&a…

Benchmarking Panoptic Scene Graph Generation (PSG), ECCV‘22 场景图生成,利用PSG数据集

2080-ti显卡复现 源代码地址 Jingkang50/OpenPSG: Benchmarking Panoptic Scene Graph Generation (PSG), ECCV22 (github.com) 安装 pytorch 1.7版本 cuda10.1 按照readme的做法安装 我安装的过程如下图所示,这个截图是到了pip install openmim这一步 下一步 下一步 这一步…

什么是FIFO管理单元?(First-In-First-Out)

FIFO(First-In-First-Out,先进先出)管理单元是一种广泛用于数据处理和存储系统中的机制,其核心理念是确保最早进入系统的数据最早被处理或移出。这种管理方法类似于排队的方式,最早进入队列的项目会最先得到服务。 FIF…

C语言 | Leetcode C语言题解之第167题两数之和II-输入有序数组

题目&#xff1a; 题解&#xff1a; int* twoSum(int* numbers, int numbersSize, int target, int* returnSize) {int* ret (int*)malloc(sizeof(int) * 2);*returnSize 2;int low 0, high numbersSize - 1;while (low < high) {int sum numbers[low] numbers[high]…

redis如何做内存优化

1、数据结构的优化 1、使用数据结构的最小存储形式。例如&#xff0c;如果你需要存储一组唯一的用户ID&#xff0c;你可以将这些信息合并到一个大的哈希表中&#xff0c;而不是为每个用户创建单独的哈希表&#xff0c;以减少固定开销。 2、使用整数编码。例如&#xff0c;存储…

如何设置MySQL远程访问权限?

MySQL是一种流行的关系型数据库管理系统&#xff0c;它广泛应用于各种Web应用程序和数据驱动的应用中。在默认情况下&#xff0c;MySQL只允许本地访问&#xff0c;为了能够从远程服务器或客户端访问MySQL数据库&#xff0c;我们需要进行一些额外的设置和配置。 安装和配置MySQ…

在阿里云使用Docker部署MySQL服务,并且通过IDEA进行连接

阿里云使用Docker部署MySQL服务&#xff0c;并且通过IDEA进行连接 这里演示如何使用阿里云来进行MySQL的部署&#xff0c;系统使用的是Linux系统 (Ubuntu)。 为什么使用Docker? 首先是因为它的可移植性可以在任何有Docker环境的系统上运行应用&#xff0c;避免了在不通操作系…

jetpack compose的@Preview和自定义主题

1.Preview Preview可以在 Android Studio 的预览窗口中实时查看和调试 UI 组件。 基本使用 import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material.MaterialTheme import androidx.compose.material.Surface import androidx.compose.ma…

【html】用html+css实现银行的账户信息表格

我们先来看一看某银行的账户信息表格 我们自己也可以实现类似的效果 效果图: 大家可以看到&#xff0c;其实效果差不多 接下来看看我们实现的代码 源码&#xff1a; <!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8"&…

机械师硬盘数据清空怎么办?机械师硬盘数据清空怎么恢复

机械师硬盘数据清空怎么恢复&#xff1f;随着数字化时代的到来&#xff0c;数据已成为我们生活和工作中不可或缺的一部分。然而&#xff0c;硬盘数据的意外清空往往会给我们带来极大的困扰。本文将探讨在机械师硬盘数据清空后&#xff0c;我们应该如何快速有效地恢复数据。 图片…

深入理解Java中的JPA与Hibernate

深入理解Java中的JPA与Hibernate 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将深入探讨Java中的JPA&#xff08;Java Persistence API&#xff09;…

二本毕业,我是如何成为BAT-安卓开发工程师?

1.对基础原理不断挖掘 进入公司&#xff0c;我的职位是Linux应用开发工程师&#xff0c;做App网络传输模块&#xff0c;本质上就是把本地的数据通过socket传输到服务端。用到的技术是C语言&#xff0c;网络编程&#xff0c;多线程编程。 那时是最痛苦的几个月&#xff0c;因为…

如何在Java中实现线程安全的单例模式

如何在Java中实现线程安全的单例模式 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将探讨在Java中如何实现线程安全的单例模式。单例模式是一种常见的…

JDBC是什么?

JDBC&#xff08;Java Database Connectivity&#xff09;称为Java数据库连接&#xff0c;它是一种用于数据库访问的应用程序API&#xff0c;由一组用Java语言编写的类和接口组成。以下是关于JDBC的详细说明&#xff1a; 1. 定义 JDBC提供了一种基准&#xff0c;据此可以构建…

[FreeRTOS 功能应用] 互斥访问与回环队列 功能应用

文章目录 一、基础知识点二、代码讲解三、结果演示四、代码下载 一、基础知识点 [FreeRTOS 基础知识] 互斥访问与回环队列 概念 [FreeRTOS 内部实现] 互斥访问与回环队列 [FreeRTOS 内部实现] 创建任务 xTaskCreate函数解析 本实验是基于STM32F103开发移植FreeRTOS实时操作系…

解决WebStorm中不显示npm任务面板

鼠标右键项目的package.json文件&#xff0c;然后点击show npm scripts选项。 然后npm工具窗口就显示了&#xff1a;

02--MySQL数据库概述

目录 第10章 子查询 10.1 SELECT的SELECT中嵌套子查询 10.2 SELECT的WHERE或HAVING中嵌套子查询 10.3 SELECT中的EXISTS型子查询 10.4 SELECT的FROM中嵌套子查询 第11章 MySQL支持的数据类型 11.1 数值类型:包括整数和小数 1、整数类型 2、bit类型 3、小数类型 11.2…

Typescript: declear

问: const book: string 这样就可以声明而且赋值为什么还用declear去分成好几步骤走呢? 同时即使不赋值只需要使用const book: string;难道不也行吗? 为什么要加上一个declear呢? 回答: 在 TypeScript 中&#xff0c;声明变量和使用 declare 声明类型信息是两个不同的概念…

【Python系列】探索 NumPy 中的 mean 函数:计算平均值的利器

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Selenium WebDriver - 动作API

本文翻译整理自&#xff1a;https://www.selenium.dev/documentation/webdriver/actions_api/ 文章目录 一、行动建设者二、暂停三、释放所有操作四、键盘动作1、钥匙2、钥匙放下3、钥匙打开4、发送钥匙活性元素指定元素 5、复制和粘贴 五、鼠标动作1、点击并按住2、点击并释放…